From c776e4eefb97f9a26e57292f26b1d0550178571e Mon Sep 17 00:00:00 2001
From: Glauber Costa
Date: Sat, 23 Aug 2025 10:16:16 -0500
Subject: [PATCH 1/9] First implementation of Logical plan

This is a first pass on logical plans. The idea is that the DBSP
compiler will have an easier time operating on a logical plan, which
exposes linear algebra operators, than on raw SQL expressions.

To keep this simple, we only support filters, aggregates, and
projections for now, and will add more later as we agree on the core
of the implementation.

To make sure that the implementation is reasonable, I generated a
couple of logical plans using DataFusion and checked whether we were
generating something similar. Our plans are not the same as
DataFusion's, though. There are two important differences:

* SQLite is weird, and it allows columns that are not part of the
  GROUP BY clause to appear in aggregated statements. For example:

    select a, count(b) from table group by c;

  that "a" is usually not permitted, and DataFusion will reject it.
  SQLite will be happy to accept it.

* DataFusion will not generate a projection on queries like this:
  select sum(hex(a)) from table, and will just keep the complex
  expression hex(a) inside the aggregation. For DBSP to work well,
  we'll need an explicit projection there.

Because there are no users yet, I am marking this as #[cfg(test)], but
I wanted to put this out there ASAP.
---
 core/translate/logical.rs | 3076 +++++++++++++++++++++++++++++++++++++
 core/translate/mod.rs     |    2 +
 2 files changed, 3078 insertions(+)
 create mode 100644 core/translate/logical.rs

diff --git a/core/translate/logical.rs b/core/translate/logical.rs
new file mode 100644
index 000000000..df8bdd13a
--- /dev/null
+++ b/core/translate/logical.rs
@@ -0,0 +1,3076 @@
+//! Logical plan representation for SQL queries
+//!
+//! This module provides a platform-independent intermediate representation
+//! for SQL queries. The logical plan is a DAG (Directed Acyclic Graph) that
+//! supports CTEs and can be used for query optimization before being compiled
+//! to an execution plan (e.g., DBSP circuits).
+//!
+//! The main entry point is `LogicalPlanBuilder` which constructs logical plans
+//! from SQL AST nodes.
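+//!
+//! As a small illustrative example (the shapes below mirror the unit tests at
+//! the bottom of this file), a query such as
+//! `SELECT name FROM users WHERE age > 18` is represented as the operator tree:
+//!
+//! ```text
+//! Projection: name
+//!   Filter: age > 18
+//!     TableScan: users
+//! ```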
+use crate::function::AggFunc;
+use crate::schema::{Schema, Type};
+use crate::types::Value;
+use crate::{LimboError, Result};
+use std::collections::HashMap;
+use std::fmt::{self, Display, Formatter};
+use std::sync::Arc;
+use turso_parser::ast;
+
+/// Result type for preprocessing aggregate expressions
+type PreprocessAggregateResult = (
+    bool,                // needs_pre_projection
+    Vec<LogicalExpr>,    // pre_projection_exprs
+    Vec<(String, Type)>, // pre_projection_schema
+    Vec<LogicalExpr>,    // modified_aggr_exprs
+);
+
+/// Schema information for logical plan nodes
+#[derive(Debug, Clone, PartialEq)]
+pub struct LogicalSchema {
+    /// Column names and types
+    pub columns: Vec<(String, Type)>,
+}
+
+/// A reference to a schema that can be shared between nodes
+pub type SchemaRef = Arc<LogicalSchema>;
+
+impl LogicalSchema {
+    pub fn new(columns: Vec<(String, Type)>) -> Self {
+        Self { columns }
+    }
+
+    pub fn empty() -> Self {
+        Self {
+            columns: Vec::new(),
+        }
+    }
+
+    pub fn column_count(&self) -> usize {
+        self.columns.len()
+    }
+
+    pub fn find_column(&self, name: &str) -> Option<(usize, &Type)> {
+        self.columns
+            .iter()
+            .position(|(n, _)| n == name)
+            .map(|idx| (idx, &self.columns[idx].1))
+    }
+}
+
+/// Logical representation of a SQL query plan
+#[derive(Debug, Clone, PartialEq)]
+pub enum LogicalPlan {
+    /// Projection - SELECT expressions
+    Projection(Projection),
+    /// Filter - WHERE/HAVING clause
+    Filter(Filter),
+    /// Aggregate - GROUP BY with aggregate functions
+    Aggregate(Aggregate),
+    // TODO: Join - combining two relations (not yet implemented)
+    // Join(Join),
+    /// Sort - ORDER BY clause
+    Sort(Sort),
+    /// Limit - LIMIT/OFFSET clause
+    Limit(Limit),
+    /// Table scan - reading from a base table
+    TableScan(TableScan),
+    /// Union - UNION/UNION ALL/INTERSECT/EXCEPT
+    Union(Union),
+    /// Distinct - remove duplicates
+    Distinct(Distinct),
+    /// Empty relation - no rows
+    EmptyRelation(EmptyRelation),
+    /// Values - literal rows (VALUES clause)
+    Values(Values),
+    /// CTE support - WITH clause
+    WithCTE(WithCTE),
+    /// Reference to a CTE
+    CTERef(CTERef),
+}
+
+impl LogicalPlan {
+    /// Get the schema of this plan node
+    pub fn schema(&self) -> &SchemaRef {
+        match self {
+            LogicalPlan::Projection(p) => &p.schema,
+            LogicalPlan::Filter(f) => f.input.schema(),
+            LogicalPlan::Aggregate(a) => &a.schema,
+            // LogicalPlan::Join(j) => &j.schema,
+            LogicalPlan::Sort(s) => s.input.schema(),
+            LogicalPlan::Limit(l) => l.input.schema(),
+            LogicalPlan::TableScan(t) => &t.schema,
+            LogicalPlan::Union(u) => &u.schema,
+            LogicalPlan::Distinct(d) => d.input.schema(),
+            LogicalPlan::EmptyRelation(e) => &e.schema,
+            LogicalPlan::Values(v) => &v.schema,
+            LogicalPlan::WithCTE(w) => w.body.schema(),
+            LogicalPlan::CTERef(c) => &c.schema,
+        }
+    }
+}
+
+/// Projection operator - SELECT expressions
+#[derive(Debug, Clone, PartialEq)]
+pub struct Projection {
+    pub input: Arc<LogicalPlan>,
+    pub exprs: Vec<LogicalExpr>,
+    pub schema: SchemaRef,
+}
+
+/// Filter operator - WHERE/HAVING predicates
+#[derive(Debug, Clone, PartialEq)]
+pub struct Filter {
+    pub input: Arc<LogicalPlan>,
+    pub predicate: LogicalExpr,
+}
+
+/// Aggregate operator - GROUP BY with aggregations
+#[derive(Debug, Clone, PartialEq)]
+pub struct Aggregate {
+    pub input: Arc<LogicalPlan>,
+    pub group_expr: Vec<LogicalExpr>,
+    pub aggr_expr: Vec<LogicalExpr>,
+    pub schema: SchemaRef,
+}
+
+// TODO: Join operator (not yet implemented)
+// #[derive(Debug, Clone, PartialEq)]
+// pub struct Join {
+//     pub left: Arc<LogicalPlan>,
+//     pub right: Arc<LogicalPlan>,
+//     pub join_type: JoinType,
+//     pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions
+//     pub filter: Option<LogicalExpr>, //
Additional filter conditions +// pub schema: SchemaRef, +// } + +// TODO: Types of joins (not yet implemented) +// #[derive(Debug, Clone, Copy, PartialEq)] +// pub enum JoinType { +// Inner, +// Left, +// Right, +// Full, +// Cross, +// } + +/// Sort operator - ORDER BY +#[derive(Debug, Clone, PartialEq)] +pub struct Sort { + pub input: Arc, + pub exprs: Vec, +} + +/// Sort expression with direction +#[derive(Debug, Clone, PartialEq)] +pub struct SortExpr { + pub expr: LogicalExpr, + pub asc: bool, + pub nulls_first: bool, +} + +/// Limit operator - LIMIT/OFFSET +#[derive(Debug, Clone, PartialEq)] +pub struct Limit { + pub input: Arc, + pub skip: Option, + pub fetch: Option, +} + +/// Table scan operator +#[derive(Debug, Clone, PartialEq)] +pub struct TableScan { + pub table_name: String, + pub schema: SchemaRef, + pub projection: Option>, // Column indices to project +} + +/// Union operator +#[derive(Debug, Clone, PartialEq)] +pub struct Union { + pub inputs: Vec>, + pub all: bool, // true for UNION ALL, false for UNION + pub schema: SchemaRef, +} + +/// Distinct operator +#[derive(Debug, Clone, PartialEq)] +pub struct Distinct { + pub input: Arc, +} + +/// Empty relation - produces no rows +#[derive(Debug, Clone, PartialEq)] +pub struct EmptyRelation { + pub produce_one_row: bool, + pub schema: SchemaRef, +} + +/// Values operator - literal rows +#[derive(Debug, Clone, PartialEq)] +pub struct Values { + pub rows: Vec>, + pub schema: SchemaRef, +} + +/// WITH clause - CTEs +#[derive(Debug, Clone, PartialEq)] +pub struct WithCTE { + pub ctes: HashMap>, + pub body: Arc, +} + +/// Reference to a CTE +#[derive(Debug, Clone, PartialEq)] +pub struct CTERef { + pub name: String, + pub schema: SchemaRef, +} + +/// Logical expression representation +#[derive(Debug, Clone, PartialEq)] +pub enum LogicalExpr { + /// Column reference + Column(Column), + /// Literal value + Literal(Value), + /// Binary expression + BinaryExpr { + left: Box, + op: BinaryOperator, + right: Box, + }, + /// Unary expression + UnaryExpr { + op: UnaryOperator, + expr: Box, + }, + /// Aggregate function + AggregateFunction { + fun: AggregateFunction, + args: Vec, + distinct: bool, + }, + /// Scalar function call + ScalarFunction { fun: String, args: Vec }, + /// CASE expression + Case { + expr: Option>, + when_then: Vec<(LogicalExpr, LogicalExpr)>, + else_expr: Option>, + }, + /// IN list + InList { + expr: Box, + list: Vec, + negated: bool, + }, + /// IN subquery + InSubquery { + expr: Box, + subquery: Arc, + negated: bool, + }, + /// EXISTS subquery + Exists { + subquery: Arc, + negated: bool, + }, + /// Scalar subquery + ScalarSubquery(Arc), + /// Alias for an expression + Alias { + expr: Box, + alias: String, + }, + /// IS NULL / IS NOT NULL + IsNull { + expr: Box, + negated: bool, + }, + /// BETWEEN + Between { + expr: Box, + low: Box, + high: Box, + negated: bool, + }, + /// LIKE pattern matching + Like { + expr: Box, + pattern: Box, + escape: Option, + negated: bool, + }, +} + +/// Column reference +#[derive(Debug, Clone, PartialEq)] +pub struct Column { + pub name: String, + pub table: Option, +} + +impl Column { + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + table: None, + } + } + + pub fn with_table(name: impl Into, table: impl Into) -> Self { + Self { + name: name.into(), + table: Some(table.into()), + } + } +} + +impl Display for Column { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match &self.table { + Some(t) => write!(f, "{}.{}", t, self.name), + None => write!(f, "{}", 
self.name), + } + } +} + +/// Type alias for binary operators +pub type BinaryOperator = ast::Operator; + +/// Type alias for unary operators +pub type UnaryOperator = ast::UnaryOperator; + +/// Type alias for aggregate functions +pub type AggregateFunction = AggFunc; + +/// Compiler from AST to LogicalPlan +pub struct LogicalPlanBuilder<'a> { + schema: &'a Schema, + ctes: HashMap>, +} + +impl<'a> LogicalPlanBuilder<'a> { + pub fn new(schema: &'a Schema) -> Self { + Self { + schema, + ctes: HashMap::new(), + } + } + + /// Main entry point: compile a statement to a logical plan + pub fn build_statement(&mut self, stmt: &ast::Stmt) -> Result { + match stmt { + ast::Stmt::Select(select) => self.build_select(select), + _ => Err(LimboError::ParseError( + "Only SELECT statements are currently supported in logical plans".to_string(), + )), + } + } + + // Convert Name to String + fn name_to_string(name: &ast::Name) -> String { + match name { + ast::Name::Ident(s) | ast::Name::Quoted(s) => s.clone(), + } + } + + // Build a SELECT statement + // Build a logical plan from a SELECT statement + fn build_select(&mut self, select: &ast::Select) -> Result { + // Handle WITH clause if present + if let Some(with) = &select.with { + return self.build_with_cte(with, select); + } + + // Build the main query body + let order_by = &select.order_by; + let limit = &select.limit; + self.build_select_body(&select.body, order_by, limit) + } + + // Build WITH CTE + fn build_with_cte(&mut self, with: &ast::With, select: &ast::Select) -> Result { + let mut cte_plans = HashMap::new(); + + // Build each CTE + for cte in &with.ctes { + let cte_plan = self.build_select(&cte.select)?; + let cte_name = Self::name_to_string(&cte.tbl_name); + cte_plans.insert(cte_name.clone(), Arc::new(cte_plan)); + self.ctes + .insert(cte_name.clone(), cte_plans[&cte_name].clone()); + } + + // Build the main body with CTEs available + let order_by = &select.order_by; + let limit = &select.limit; + let body = self.build_select_body(&select.body, order_by, limit)?; + + // Clear CTEs from builder context + for cte in &with.ctes { + self.ctes.remove(&Self::name_to_string(&cte.tbl_name)); + } + + Ok(LogicalPlan::WithCTE(WithCTE { + ctes: cte_plans, + body: Arc::new(body), + })) + } + + // Build SELECT body + fn build_select_body( + &mut self, + body: &ast::SelectBody, + order_by: &[ast::SortedColumn], + limit: &Option, + ) -> Result { + let mut plan = self.build_one_select(&body.select)?; + + // Handle compound operators (UNION, INTERSECT, EXCEPT) + if !body.compounds.is_empty() { + for compound in &body.compounds { + let right = self.build_one_select(&compound.select)?; + plan = Self::build_compound(plan, right, &compound.operator)?; + } + } + + // Apply ORDER BY + if !order_by.is_empty() { + plan = self.build_sort(plan, order_by)?; + } + + // Apply LIMIT + if let Some(limit) = limit { + plan = Self::build_limit(plan, limit)?; + } + + Ok(plan) + } + + // Build a single SELECT (without compounds) + fn build_one_select(&mut self, select: &ast::OneSelect) -> Result { + match select { + ast::OneSelect::Select { + distinctness, + columns, + from, + where_clause, + group_by, + window_clause: _, + } => { + // Start with FROM clause + let mut plan = if let Some(from) = from { + self.build_from(from)? 
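+                    // NOTE: after the FROM clause is resolved, the clauses
+                    // below are applied in SQL's logical evaluation order:
+                    // WHERE -> GROUP BY/aggregates -> HAVING -> DISTINCT.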
+ } else { + // No FROM clause - single row + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(LogicalSchema::empty()), + }) + }; + + // Apply WHERE + if let Some(where_expr) = where_clause { + let predicate = self.build_expr(where_expr, plan.schema())?; + plan = LogicalPlan::Filter(Filter { + input: Arc::new(plan), + predicate, + }); + } + + // Apply GROUP BY and aggregations + if let Some(group_by) = group_by { + plan = self.build_aggregate(plan, group_by, columns)?; + } else if Self::has_aggregates(columns) { + // Aggregation without GROUP BY + plan = self.build_aggregate_no_group(plan, columns)?; + } else { + // Regular projection + plan = self.build_projection(plan, columns)?; + } + + // Apply HAVING (part of GROUP BY) + if let Some(ref group_by) = group_by { + if let Some(ref having_expr) = group_by.having { + let predicate = self.build_expr(having_expr, plan.schema())?; + plan = LogicalPlan::Filter(Filter { + input: Arc::new(plan), + predicate, + }); + } + } + + // Apply DISTINCT + if distinctness.is_some() { + plan = LogicalPlan::Distinct(Distinct { + input: Arc::new(plan), + }); + } + + Ok(plan) + } + ast::OneSelect::Values(values) => self.build_values(values), + } + } + + // Build FROM clause + fn build_from(&mut self, from: &ast::FromClause) -> Result { + let mut plan = { self.build_select_table(&from.select)? }; + + // Handle JOINs + if !from.joins.is_empty() { + for join in &from.joins { + let right = self.build_select_table(&join.table)?; + plan = self.build_join(plan, right, &join.operator, &join.constraint)?; + } + } + + Ok(plan) + } + + // Build a table reference + fn build_select_table(&mut self, table: &ast::SelectTable) -> Result { + match table { + ast::SelectTable::Table(name, _alias, _indexed) => { + let table_name = Self::name_to_string(&name.name); + // Check if it's a CTE reference + if let Some(cte_plan) = self.ctes.get(&table_name) { + return Ok(LogicalPlan::CTERef(CTERef { + name: table_name.clone(), + schema: cte_plan.schema().clone(), + })); + } + + // Regular table scan + let table_schema = self.get_table_schema(&table_name)?; + Ok(LogicalPlan::TableScan(TableScan { + table_name, + schema: table_schema, + projection: None, + })) + } + ast::SelectTable::Select(subquery, _alias) => self.build_select(subquery), + ast::SelectTable::TableCall(_, _, _) => Err(LimboError::ParseError( + "Table-valued functions are not supported in logical plans".to_string(), + )), + ast::SelectTable::Sub(_, _) => Err(LimboError::ParseError( + "Subquery in FROM clause not yet supported".to_string(), + )), + } + } + + // Build JOIN + fn build_join( + &mut self, + _left: LogicalPlan, + _right: LogicalPlan, + _op: &ast::JoinOperator, + _constraint: &Option, + ) -> Result { + Err(LimboError::ParseError( + "JOINs are not yet supported in logical plans".to_string(), + )) + } + + // Build projection + fn build_projection( + &mut self, + input: LogicalPlan, + columns: &[ast::ResultColumn], + ) -> Result { + let input_schema = input.schema(); + let mut proj_exprs = Vec::new(); + let mut schema_columns = Vec::new(); + + for col in columns { + match col { + ast::ResultColumn::Expr(expr, alias) => { + let logical_expr = self.build_expr(expr, input_schema)?; + let col_name = match alias { + Some(as_alias) => match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }, + None => Self::expr_to_column_name(expr), + }; + let col_type = Self::infer_expr_type(&logical_expr, input_schema)?; + + 
schema_columns.push((col_name.clone(), col_type)); + + if let Some(as_alias) = alias { + let alias_name = match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }; + proj_exprs.push(LogicalExpr::Alias { + expr: Box::new(logical_expr), + alias: alias_name, + }); + } else { + proj_exprs.push(logical_expr); + } + } + ast::ResultColumn::Star => { + // Expand * to all columns + for (name, typ) in &input_schema.columns { + proj_exprs.push(LogicalExpr::Column(Column::new(name.clone()))); + schema_columns.push((name.clone(), *typ)); + } + } + ast::ResultColumn::TableStar(table) => { + // Expand table.* to all columns from that table + let table_name = Self::name_to_string(table); + for (name, typ) in &input_schema.columns { + // Simple check - would need proper table tracking in real implementation + proj_exprs.push(LogicalExpr::Column(Column::with_table( + name.clone(), + table_name.clone(), + ))); + schema_columns.push((name.clone(), *typ)); + } + } + } + } + + Ok(LogicalPlan::Projection(Projection { + input: Arc::new(input), + exprs: proj_exprs, + schema: Arc::new(LogicalSchema::new(schema_columns)), + })) + } + + // Helper function to preprocess aggregate expressions that contain complex arguments + // Returns: (needs_pre_projection, pre_projection_exprs, pre_projection_schema, modified_aggr_exprs) + // + // This will be used in expressions like select sum(hex(a + 2)) from tbl => hex(a + 2) is a + // pre-projection. + // + // Another alternative is to always generate a projection together with an aggregation, and + // just have "a" be the identity projection if we don't have a complex case. But that's quite + // wasteful. + fn preprocess_aggregate_expressions( + aggr_exprs: &[LogicalExpr], + group_exprs: &[LogicalExpr], + input_schema: &SchemaRef, + ) -> Result { + let mut needs_pre_projection = false; + let mut pre_projection_exprs = Vec::new(); + let mut pre_projection_schema = Vec::new(); + let mut modified_aggr_exprs = Vec::new(); + let mut projected_col_counter = 0; + + // First, add all group by expressions to the pre-projection + for expr in group_exprs { + if let LogicalExpr::Column(col) = expr { + pre_projection_exprs.push(expr.clone()); + let col_type = Self::infer_expr_type(expr, input_schema)?; + pre_projection_schema.push((col.name.clone(), col_type)); + } else { + // Complex group by expression - project it + needs_pre_projection = true; + let proj_col_name = format!("__group_proj_{projected_col_counter}"); + projected_col_counter += 1; + pre_projection_exprs.push(expr.clone()); + let col_type = Self::infer_expr_type(expr, input_schema)?; + pre_projection_schema.push((proj_col_name.clone(), col_type)); + } + } + + // Check each aggregate expression + for agg_expr in aggr_exprs { + if let LogicalExpr::AggregateFunction { + fun, + args, + distinct, + } = agg_expr + { + let mut modified_args = Vec::new(); + for arg in args { + // Check if the argument is a simple column reference or a complex expression + match arg { + LogicalExpr::Column(_) => { + // Simple column - just use it + modified_args.push(arg.clone()); + // Make sure the column is in the pre-projection + if !pre_projection_exprs.iter().any(|e| e == arg) { + pre_projection_exprs.push(arg.clone()); + let col_type = Self::infer_expr_type(arg, input_schema)?; + if let LogicalExpr::Column(col) = arg { + pre_projection_schema.push((col.name.clone(), col_type)); + } + } + } + _ => { + // Complex expression - we need to project it first + needs_pre_projection = true; + let proj_col_name = 
format!("__agg_arg_proj_{projected_col_counter}"); + projected_col_counter += 1; + + // Add the expression to the pre-projection + pre_projection_exprs.push(arg.clone()); + let col_type = Self::infer_expr_type(arg, input_schema)?; + pre_projection_schema.push((proj_col_name.clone(), col_type)); + + // In the aggregate, reference the projected column + modified_args.push(LogicalExpr::Column(Column::new(proj_col_name))); + } + } + } + + // Create the modified aggregate expression + modified_aggr_exprs.push(LogicalExpr::AggregateFunction { + fun: fun.clone(), + args: modified_args, + distinct: *distinct, + }); + } else { + modified_aggr_exprs.push(agg_expr.clone()); + } + } + + Ok(( + needs_pre_projection, + pre_projection_exprs, + pre_projection_schema, + modified_aggr_exprs, + )) + } + + // Build aggregate with GROUP BY + fn build_aggregate( + &mut self, + input: LogicalPlan, + group_by: &ast::GroupBy, + columns: &[ast::ResultColumn], + ) -> Result { + let input_schema = input.schema(); + + // Build grouping expressions + let mut group_exprs = Vec::new(); + for expr in &group_by.exprs { + group_exprs.push(self.build_expr(expr, input_schema)?); + } + + // Use the unified aggregate builder + self.build_aggregate_internal(input, group_exprs, columns) + } + + // Build aggregate without GROUP BY + fn build_aggregate_no_group( + &mut self, + input: LogicalPlan, + columns: &[ast::ResultColumn], + ) -> Result { + // Use the unified aggregate builder with empty group expressions + self.build_aggregate_internal(input, vec![], columns) + } + + // Unified internal aggregate builder that handles both GROUP BY and non-GROUP BY cases + fn build_aggregate_internal( + &mut self, + input: LogicalPlan, + group_exprs: Vec, + columns: &[ast::ResultColumn], + ) -> Result { + let input_schema = input.schema(); + let has_group_by = !group_exprs.is_empty(); + + // Build aggregate expressions and projection expressions + let mut aggr_exprs = Vec::new(); + let mut projection_exprs = Vec::new(); + let mut aggregate_schema_columns = Vec::new(); + + // First, add GROUP BY columns to the aggregate output schema + // These are always part of the aggregate operator's output + for group_expr in &group_exprs { + let col_name = match group_expr { + LogicalExpr::Column(col) => col.name.clone(), + _ => { + // For complex GROUP BY expressions, generate a name + format!("__group_{}", aggregate_schema_columns.len()) + } + }; + let col_type = Self::infer_expr_type(group_expr, input_schema)?; + aggregate_schema_columns.push((col_name, col_type)); + } + + // Track aggregates we've already seen to avoid duplicates + let mut aggregate_map: std::collections::HashMap = + std::collections::HashMap::new(); + + for col in columns { + match col { + ast::ResultColumn::Expr(expr, alias) => { + let logical_expr = self.build_expr(expr, input_schema)?; + + // Determine the column name for this expression + let col_name = match alias { + Some(as_alias) => match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }, + None => Self::expr_to_column_name(expr), + }; + + // Check if the TOP-LEVEL expression is an aggregate + // We only care about immediate aggregates, not nested ones + if Self::is_aggregate_expr(&logical_expr) { + // Pure aggregate function - check if we've seen it before + let agg_key = format!("{logical_expr:?}"); + + let agg_col_name = if let Some(existing_name) = aggregate_map.get(&agg_key) + { + // Reuse existing aggregate + existing_name.clone() + } else { + // New aggregate - add it + let 
col_type = Self::infer_expr_type(&logical_expr, input_schema)?; + aggregate_schema_columns.push((col_name.clone(), col_type)); + aggr_exprs.push(logical_expr); + aggregate_map.insert(agg_key, col_name.clone()); + col_name.clone() + }; + + // In the projection, reference this aggregate by name + projection_exprs.push(LogicalExpr::Column(Column { + name: agg_col_name, + table: None, + })); + } else if Self::contains_aggregate(&logical_expr) { + // This is an expression that contains an aggregate somewhere + // (e.g., sum(a + 2) * 2) + // We need to extract aggregates and replace them with column references + let (processed_expr, extracted_aggs) = + Self::extract_and_replace_aggregates_with_dedup( + logical_expr, + &mut aggregate_map, + )?; + + // Add only new aggregates + for (agg_expr, agg_name) in extracted_aggs { + let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?; + aggregate_schema_columns.push((agg_name, agg_type)); + aggr_exprs.push(agg_expr); + } + + // Add the processed expression (with column refs) to projection + projection_exprs.push(processed_expr); + } else { + // Non-aggregate expression - validation depends on GROUP BY presence + if has_group_by { + // With GROUP BY: only allow constants and grouped columns + // TODO: SQLite actually allows any column here and returns the value from + // the first row encountered in each group. We should support this in the + // future for full SQLite compatibility, but for now we're stricter to + // simplify the DBSP compilation. + if !Self::is_constant_expr(&logical_expr) + && !Self::is_valid_in_group_by(&logical_expr, &group_exprs) + { + return Err(LimboError::ParseError(format!( + "Column '{col_name}' must appear in the GROUP BY clause or be used in an aggregate function" + ))); + } + } else { + // Without GROUP BY: only allow constant expressions + // TODO: SQLite allows any column here and returns a value from an + // arbitrary row. We should support this for full compatibility, + // but for now we're stricter to simplify DBSP compilation. 
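+                            //
+                            // Illustrative examples of the current, stricter rule:
+                            //   SELECT 1, COUNT(*) FROM users    -- accepted: "1" is constant
+                            //   SELECT name, COUNT(*) FROM users -- rejected here, even though
+                            //                                       SQLite would accept it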
+ if !Self::is_constant_expr(&logical_expr) { + return Err(LimboError::ParseError(format!( + "Column '{col_name}' must be used in an aggregate function when using aggregates without GROUP BY" + ))); + } + } + projection_exprs.push(logical_expr); + } + } + _ => { + let error_msg = if has_group_by { + "* not supported with GROUP BY".to_string() + } else { + "* not supported with aggregate functions".to_string() + }; + return Err(LimboError::ParseError(error_msg)); + } + } + } + + // Check if any aggregate functions have complex expressions as arguments + // If so, we need to insert a projection before the aggregate + let ( + needs_pre_projection, + pre_projection_exprs, + pre_projection_schema, + modified_aggr_exprs, + ) = Self::preprocess_aggregate_expressions(&aggr_exprs, &group_exprs, input_schema)?; + + // Build the final schema for the projection + let mut projection_schema_columns = Vec::new(); + for (i, expr) in projection_exprs.iter().enumerate() { + let col_name = if i < columns.len() { + match &columns[i] { + ast::ResultColumn::Expr(e, alias) => match alias { + Some(as_alias) => match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }, + None => Self::expr_to_column_name(e), + }, + _ => format!("col_{i}"), + } + } else { + format!("col_{i}") + }; + + // For type inference, we need the aggregate schema for column references + let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone()); + let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?; + projection_schema_columns.push((col_name, col_type)); + } + + // Create the input plan (with pre-projection if needed) + let aggregate_input = if needs_pre_projection { + Arc::new(LogicalPlan::Projection(Projection { + input: Arc::new(input), + exprs: pre_projection_exprs, + schema: Arc::new(LogicalSchema::new(pre_projection_schema)), + })) + } else { + Arc::new(input) + }; + + // Use modified aggregate expressions if we inserted a pre-projection + let final_aggr_exprs = if needs_pre_projection { + modified_aggr_exprs + } else { + aggr_exprs + }; + + // Check if we need the outer projection + // We need a projection if: + // 1. Any expression is more complex than a simple column reference (e.g., abs(sum(id))) + // 2. We're selecting a different set of columns than what the aggregate outputs + // 3. 
Columns are renamed or reordered + let needs_outer_projection = { + // Check if any expression is more complex than a simple column reference + let has_complex_exprs = projection_exprs + .iter() + .any(|expr| !matches!(expr, LogicalExpr::Column(_))); + + if has_complex_exprs { + true + } else { + // All are simple columns - check if we're selecting exactly what the aggregate outputs + // The projection might be selecting a subset (e.g., only aggregates without group columns) + // or reordering columns, or using different names + + // For now, keep it simple: if schemas don't match exactly, we need projection + // This handles all cases: subset selection, reordering, renaming + projection_schema_columns != aggregate_schema_columns + } + }; + + // Create the aggregate node + let aggregate_plan = LogicalPlan::Aggregate(Aggregate { + input: aggregate_input, + group_expr: group_exprs, + aggr_expr: final_aggr_exprs, + schema: Arc::new(LogicalSchema::new(aggregate_schema_columns)), + }); + + if needs_outer_projection { + Ok(LogicalPlan::Projection(Projection { + input: Arc::new(aggregate_plan), + exprs: projection_exprs, + schema: Arc::new(LogicalSchema::new(projection_schema_columns)), + })) + } else { + // No projection needed - the aggregate output is exactly what we want + Ok(aggregate_plan) + } + } + + /// Build VALUES clause + #[allow(clippy::vec_box)] + fn build_values(&mut self, values: &[Vec>]) -> Result { + if values.is_empty() { + return Err(LimboError::ParseError("Empty VALUES clause".to_string())); + } + + let mut rows = Vec::new(); + let first_row_len = values[0].len(); + + // Infer schema from first row + let mut schema_columns = Vec::new(); + for (i, _) in values[0].iter().enumerate() { + schema_columns.push((format!("column{}", i + 1), Type::Text)); + } + + for row in values { + if row.len() != first_row_len { + return Err(LimboError::ParseError( + "All rows in VALUES must have the same number of columns".to_string(), + )); + } + + let mut logical_row = Vec::new(); + for expr in row { + // VALUES doesn't have input schema + let empty_schema = Arc::new(LogicalSchema::empty()); + logical_row.push(self.build_expr(expr, &empty_schema)?); + } + rows.push(logical_row); + } + + Ok(LogicalPlan::Values(Values { + rows, + schema: Arc::new(LogicalSchema::new(schema_columns)), + })) + } + + // Build SORT + fn build_sort( + &mut self, + input: LogicalPlan, + exprs: &[ast::SortedColumn], + ) -> Result { + let input_schema = input.schema(); + let mut sort_exprs = Vec::new(); + + for sorted_col in exprs { + let expr = self.build_expr(&sorted_col.expr, input_schema)?; + sort_exprs.push(SortExpr { + expr, + asc: sorted_col.order != Some(ast::SortOrder::Desc), + nulls_first: sorted_col.nulls == Some(ast::NullsOrder::First), + }); + } + + Ok(LogicalPlan::Sort(Sort { + input: Arc::new(input), + exprs: sort_exprs, + })) + } + + // Build LIMIT + fn build_limit(input: LogicalPlan, limit: &ast::Limit) -> Result { + let fetch = match limit.expr.as_ref() { + ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::().ok(), + _ => { + return Err(LimboError::ParseError( + "LIMIT must be a literal integer".to_string(), + )) + } + }; + + let skip = if let Some(offset) = &limit.offset { + match offset.as_ref() { + ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::().ok(), + _ => { + return Err(LimboError::ParseError( + "OFFSET must be a literal integer".to_string(), + )) + } + } + } else { + None + }; + + Ok(LogicalPlan::Limit(Limit { + input: Arc::new(input), + skip, + fetch, + })) + } + + // Build 
compound operator (UNION, INTERSECT, EXCEPT) + fn build_compound( + left: LogicalPlan, + right: LogicalPlan, + op: &ast::CompoundOperator, + ) -> Result { + // Check schema compatibility + if left.schema().column_count() != right.schema().column_count() { + return Err(LimboError::ParseError( + "UNION/INTERSECT/EXCEPT requires same number of columns".to_string(), + )); + } + + let all = matches!(op, ast::CompoundOperator::UnionAll); + + match op { + ast::CompoundOperator::Union | ast::CompoundOperator::UnionAll => { + let schema = left.schema().clone(); + Ok(LogicalPlan::Union(Union { + inputs: vec![Arc::new(left), Arc::new(right)], + all, + schema, + })) + } + _ => Err(LimboError::ParseError( + "INTERSECT and EXCEPT not yet supported in logical plans".to_string(), + )), + } + } + + // Build expression from AST + fn build_expr(&mut self, expr: &ast::Expr, _schema: &SchemaRef) -> Result { + match expr { + ast::Expr::Id(name) => Ok(LogicalExpr::Column(Column::new(Self::name_to_string(name)))), + + ast::Expr::DoublyQualified(db, table, col) => { + Ok(LogicalExpr::Column(Column::with_table( + Self::name_to_string(col), + format!( + "{}.{}", + Self::name_to_string(db), + Self::name_to_string(table) + ), + ))) + } + + ast::Expr::Qualified(table, col) => Ok(LogicalExpr::Column(Column::with_table( + Self::name_to_string(col), + Self::name_to_string(table), + ))), + + ast::Expr::Literal(lit) => Ok(LogicalExpr::Literal(Self::build_literal(lit)?)), + + ast::Expr::Binary(lhs, op, rhs) => { + // Special case: IS NULL and IS NOT NULL + if matches!(op, ast::Operator::Is | ast::Operator::IsNot) { + if let ast::Expr::Literal(ast::Literal::Null) = rhs.as_ref() { + let expr = Box::new(self.build_expr(lhs, _schema)?); + return Ok(LogicalExpr::IsNull { + expr, + negated: matches!(op, ast::Operator::IsNot), + }); + } + } + + let left = Box::new(self.build_expr(lhs, _schema)?); + let right = Box::new(self.build_expr(rhs, _schema)?); + Ok(LogicalExpr::BinaryExpr { + left, + op: *op, + right, + }) + } + + ast::Expr::Unary(op, expr) => { + let inner = Box::new(self.build_expr(expr, _schema)?); + Ok(LogicalExpr::UnaryExpr { + op: *op, + expr: inner, + }) + } + + ast::Expr::FunctionCall { + name, + distinctness, + args, + filter_over, + .. + } => { + // Check for window functions (OVER clause) + if filter_over.over_clause.is_some() { + return Err(LimboError::ParseError( + "Unsupported expression type: window functions are not yet supported" + .to_string(), + )); + } + + let func_name = Self::name_to_string(name); + let arg_count = args.len(); + // Check if it's an aggregate function (considering argument count for min/max) + if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, arg_count) { + let distinct = distinctness.is_some(); + let arg_exprs = args + .iter() + .map(|e| self.build_expr(e, _schema)) + .collect::>>()?; + Ok(LogicalExpr::AggregateFunction { + fun: agg_fun, + args: arg_exprs, + distinct, + }) + } else { + // Regular scalar function + let arg_exprs = args + .iter() + .map(|e| self.build_expr(e, _schema)) + .collect::>>()?; + Ok(LogicalExpr::ScalarFunction { + fun: func_name, + args: arg_exprs, + }) + } + } + + ast::Expr::FunctionCallStar { name, .. 
} => { + // Handle COUNT(*) and similar + let func_name = Self::name_to_string(name); + // FunctionCallStar always has 0 args (it's the * form) + if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, 0) { + Ok(LogicalExpr::AggregateFunction { + fun: agg_fun, + args: vec![], + distinct: false, + }) + } else { + Err(LimboError::ParseError(format!( + "Function {func_name}(*) is not supported" + ))) + } + } + + ast::Expr::Case { + base, + when_then_pairs, + else_expr, + } => { + let case_expr = if let Some(e) = base { + Some(Box::new(self.build_expr(e, _schema)?)) + } else { + None + }; + + let when_then_exprs = when_then_pairs + .iter() + .map(|(when, then)| { + Ok(( + self.build_expr(when, _schema)?, + self.build_expr(then, _schema)?, + )) + }) + .collect::>>()?; + + let else_result = if let Some(e) = else_expr { + Some(Box::new(self.build_expr(e, _schema)?)) + } else { + None + }; + + Ok(LogicalExpr::Case { + expr: case_expr, + when_then: when_then_exprs, + else_expr: else_result, + }) + } + + ast::Expr::InList { lhs, not, rhs } => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + let list = rhs + .iter() + .map(|e| self.build_expr(e, _schema)) + .collect::>>()?; + Ok(LogicalExpr::InList { + expr, + list, + negated: *not, + }) + } + + ast::Expr::InSelect { lhs, not, rhs } => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + let subquery = Arc::new(self.build_select(rhs)?); + Ok(LogicalExpr::InSubquery { + expr, + subquery, + negated: *not, + }) + } + + ast::Expr::Exists(select) => { + let subquery = Arc::new(self.build_select(select)?); + Ok(LogicalExpr::Exists { + subquery, + negated: false, + }) + } + + ast::Expr::Subquery(select) => { + let subquery = Arc::new(self.build_select(select)?); + Ok(LogicalExpr::ScalarSubquery(subquery)) + } + + ast::Expr::IsNull(lhs) => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + Ok(LogicalExpr::IsNull { + expr, + negated: false, + }) + } + + ast::Expr::NotNull(lhs) => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + Ok(LogicalExpr::IsNull { + expr, + negated: true, + }) + } + + ast::Expr::Between { + lhs, + not, + start, + end, + } => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + let low = Box::new(self.build_expr(start, _schema)?); + let high = Box::new(self.build_expr(end, _schema)?); + Ok(LogicalExpr::Between { + expr, + low, + high, + negated: *not, + }) + } + + ast::Expr::Like { + lhs, + not, + op: _, + rhs, + escape, + } => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + let pattern = Box::new(self.build_expr(rhs, _schema)?); + let escape_char = escape.as_ref().and_then(|e| { + if let ast::Expr::Literal(ast::Literal::String(s)) = e.as_ref() { + s.chars().next() + } else { + None + } + }); + Ok(LogicalExpr::Like { + expr, + pattern, + escape: escape_char, + negated: *not, + }) + } + + ast::Expr::Parenthesized(exprs) => { + // the assumption is that there is at least one parenthesis here. + // If this is not true, then I don't understand this code and can't be trusted. 
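+                // e.g. an expression like `(a + b)` typically arrives as
+                // Parenthesized([a + b]) with a single element, which we
+                // unwrap below.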
+ assert!(!exprs.is_empty()); + // Multiple expressions in parentheses is unusual but handle it + // by building the first one (SQLite behavior) + self.build_expr(&exprs[0], _schema) + } + + _ => Err(LimboError::ParseError(format!( + "Unsupported expression type in logical plan: {expr:?}" + ))), + } + } + + /// Build literal value + fn build_literal(lit: &ast::Literal) -> Result { + match lit { + ast::Literal::Null => Ok(Value::Null), + ast::Literal::Keyword(k) if k.eq_ignore_ascii_case("true") => Ok(Value::Integer(1)), // SQLite uses int for bool + ast::Literal::Keyword(k) if k.eq_ignore_ascii_case("false") => Ok(Value::Integer(0)), // SQLite uses int for bool + ast::Literal::Keyword(k) => Ok(Value::Text(k.clone().into())), + ast::Literal::Numeric(s) => { + if let Ok(i) = s.parse::() { + Ok(Value::Integer(i)) + } else if let Ok(f) = s.parse::() { + Ok(Value::Float(f)) + } else { + Ok(Value::Text(s.clone().into())) + } + } + ast::Literal::String(s) => { + // Strip surrounding quotes from the SQL literal + // The parser includes quotes in the string value + let unquoted = if s.starts_with('\'') && s.ends_with('\'') && s.len() > 1 { + &s[1..s.len() - 1] + } else { + s.as_str() + }; + Ok(Value::Text(unquoted.to_string().into())) + } + ast::Literal::Blob(b) => Ok(Value::Blob(b.clone().into())), + ast::Literal::CurrentDate + | ast::Literal::CurrentTime + | ast::Literal::CurrentTimestamp => Err(LimboError::ParseError( + "Temporal literals not yet supported".to_string(), + )), + } + } + + /// Parse aggregate function name (considering argument count for min/max) + fn parse_aggregate_function(name: &str, arg_count: usize) -> Option { + match name.to_uppercase().as_str() { + "COUNT" => Some(AggFunc::Count), + "SUM" => Some(AggFunc::Sum), + "AVG" => Some(AggFunc::Avg), + // MIN and MAX are only aggregates with 1 argument + // With 2+ arguments, they're scalar functions + "MIN" if arg_count == 1 => Some(AggFunc::Min), + "MAX" if arg_count == 1 => Some(AggFunc::Max), + "GROUP_CONCAT" => Some(AggFunc::GroupConcat), + "STRING_AGG" => Some(AggFunc::StringAgg), + "TOTAL" => Some(AggFunc::Total), + _ => None, + } + } + + // Check if expression contains aggregates + fn has_aggregates(columns: &[ast::ResultColumn]) -> bool { + for col in columns { + if let ast::ResultColumn::Expr(expr, _) = col { + if Self::expr_has_aggregate(expr) { + return true; + } + } + } + false + } + + // Check if AST expression contains aggregates + fn expr_has_aggregate(expr: &ast::Expr) -> bool { + match expr { + ast::Expr::FunctionCall { name, args, .. } => { + // Check if the function itself is an aggregate (considering arg count for min/max) + let arg_count = args.len(); + if Self::parse_aggregate_function(&Self::name_to_string(name), arg_count).is_some() + { + return true; + } + // Also check if any arguments contain aggregates (for nested functions like HEX(SUM(...))) + args.iter().any(|arg| Self::expr_has_aggregate(arg)) + } + ast::Expr::FunctionCallStar { name, .. } => { + // FunctionCallStar always has 0 args + Self::parse_aggregate_function(&Self::name_to_string(name), 0).is_some() + } + ast::Expr::Binary(lhs, _, rhs) => { + Self::expr_has_aggregate(lhs) || Self::expr_has_aggregate(rhs) + } + ast::Expr::Unary(_, e) => Self::expr_has_aggregate(e), + ast::Expr::Case { + when_then_pairs, + else_expr, + .. 
+ } => { + when_then_pairs + .iter() + .any(|(w, t)| Self::expr_has_aggregate(w) || Self::expr_has_aggregate(t)) + || else_expr + .as_ref() + .is_some_and(|e| Self::expr_has_aggregate(e)) + } + ast::Expr::Parenthesized(exprs) => { + // Check if any parenthesized expression contains an aggregate + exprs.iter().any(|e| Self::expr_has_aggregate(e)) + } + _ => false, + } + } + + // Check if logical expression is an aggregate + fn is_aggregate_expr(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::AggregateFunction { .. } => true, + LogicalExpr::Alias { expr, .. } => Self::is_aggregate_expr(expr), + _ => false, + } + } + + // Check if logical expression contains an aggregate anywhere + fn contains_aggregate(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::AggregateFunction { .. } => true, + LogicalExpr::Alias { expr, .. } => Self::contains_aggregate(expr), + LogicalExpr::BinaryExpr { left, right, .. } => { + Self::contains_aggregate(left) || Self::contains_aggregate(right) + } + LogicalExpr::UnaryExpr { expr, .. } => Self::contains_aggregate(expr), + LogicalExpr::ScalarFunction { args, .. } => args.iter().any(Self::contains_aggregate), + LogicalExpr::Case { + when_then, + else_expr, + .. + } => { + when_then + .iter() + .any(|(w, t)| Self::contains_aggregate(w) || Self::contains_aggregate(t)) + || else_expr + .as_ref() + .is_some_and(|e| Self::contains_aggregate(e)) + } + _ => false, + } + } + + // Check if an expression is a constant (contains only literals) + fn is_constant_expr(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::Literal(_) => true, + LogicalExpr::BinaryExpr { left, right, .. } => { + Self::is_constant_expr(left) && Self::is_constant_expr(right) + } + LogicalExpr::UnaryExpr { expr, .. } => Self::is_constant_expr(expr), + LogicalExpr::ScalarFunction { args, .. } => args.iter().all(Self::is_constant_expr), + LogicalExpr::Alias { expr, .. } => Self::is_constant_expr(expr), + _ => false, + } + } + + // Check if an expression is valid in GROUP BY context + // An expression is valid if it's: + // 1. A constant literal + // 2. An aggregate function + // 3. A grouping column (or expression involving only grouping columns) + fn is_valid_in_group_by(expr: &LogicalExpr, group_exprs: &[LogicalExpr]) -> bool { + match expr { + LogicalExpr::Literal(_) => true, // Constants are always valid + LogicalExpr::AggregateFunction { .. } => true, // Aggregates are valid + LogicalExpr::Column(col) => { + // Check if this column is in the GROUP BY + group_exprs.iter().any(|g| match g { + LogicalExpr::Column(gcol) => gcol.name == col.name, + _ => false, + }) + } + LogicalExpr::BinaryExpr { left, right, .. } => { + // Both sides must be valid + Self::is_valid_in_group_by(left, group_exprs) + && Self::is_valid_in_group_by(right, group_exprs) + } + LogicalExpr::UnaryExpr { expr, .. } => Self::is_valid_in_group_by(expr, group_exprs), + LogicalExpr::ScalarFunction { args, .. } => { + // All arguments must be valid + args.iter() + .all(|arg| Self::is_valid_in_group_by(arg, group_exprs)) + } + LogicalExpr::Alias { expr, .. 
} => Self::is_valid_in_group_by(expr, group_exprs), + _ => false, // Other expressions are not valid + } + } + + // Extract aggregates from an expression and replace them with column references, with deduplication + // Returns the modified expression and a list of NEW (aggregate_expr, column_name) pairs + fn extract_and_replace_aggregates_with_dedup( + expr: LogicalExpr, + aggregate_map: &mut std::collections::HashMap, + ) -> Result<(LogicalExpr, Vec<(LogicalExpr, String)>)> { + let mut new_aggregates = Vec::new(); + let mut counter = aggregate_map.len(); + let new_expr = Self::replace_aggregates_with_columns_dedup( + expr, + &mut new_aggregates, + aggregate_map, + &mut counter, + )?; + Ok((new_expr, new_aggregates)) + } + + // Recursively replace aggregate functions with column references, with deduplication + fn replace_aggregates_with_columns_dedup( + expr: LogicalExpr, + new_aggregates: &mut Vec<(LogicalExpr, String)>, + aggregate_map: &mut std::collections::HashMap, + counter: &mut usize, + ) -> Result { + match expr { + LogicalExpr::AggregateFunction { .. } => { + // Found an aggregate - check if we've seen it before + let agg_key = format!("{expr:?}"); + + let col_name = if let Some(existing_name) = aggregate_map.get(&agg_key) { + // Reuse existing aggregate + existing_name.clone() + } else { + // New aggregate + let col_name = format!("__agg_{}", *counter); + *counter += 1; + aggregate_map.insert(agg_key, col_name.clone()); + new_aggregates.push((expr, col_name.clone())); + col_name + }; + + Ok(LogicalExpr::Column(Column { + name: col_name, + table: None, + })) + } + LogicalExpr::BinaryExpr { left, op, right } => { + let new_left = Self::replace_aggregates_with_columns_dedup( + *left, + new_aggregates, + aggregate_map, + counter, + )?; + let new_right = Self::replace_aggregates_with_columns_dedup( + *right, + new_aggregates, + aggregate_map, + counter, + )?; + Ok(LogicalExpr::BinaryExpr { + left: Box::new(new_left), + op, + right: Box::new(new_right), + }) + } + LogicalExpr::UnaryExpr { op, expr } => { + let new_expr = Self::replace_aggregates_with_columns_dedup( + *expr, + new_aggregates, + aggregate_map, + counter, + )?; + Ok(LogicalExpr::UnaryExpr { + op, + expr: Box::new(new_expr), + }) + } + LogicalExpr::ScalarFunction { fun, args } => { + let mut new_args = Vec::new(); + for arg in args { + new_args.push(Self::replace_aggregates_with_columns_dedup( + arg, + new_aggregates, + aggregate_map, + counter, + )?); + } + Ok(LogicalExpr::ScalarFunction { + fun, + args: new_args, + }) + } + LogicalExpr::Case { + expr: case_expr, + when_then, + else_expr, + } => { + let new_case_expr = if let Some(e) = case_expr { + Some(Box::new(Self::replace_aggregates_with_columns_dedup( + *e, + new_aggregates, + aggregate_map, + counter, + )?)) + } else { + None + }; + + let mut new_when_then = Vec::new(); + for (when, then) in when_then { + let new_when = Self::replace_aggregates_with_columns_dedup( + when, + new_aggregates, + aggregate_map, + counter, + )?; + let new_then = Self::replace_aggregates_with_columns_dedup( + then, + new_aggregates, + aggregate_map, + counter, + )?; + new_when_then.push((new_when, new_then)); + } + + let new_else = if let Some(e) = else_expr { + Some(Box::new(Self::replace_aggregates_with_columns_dedup( + *e, + new_aggregates, + aggregate_map, + counter, + )?)) + } else { + None + }; + + Ok(LogicalExpr::Case { + expr: new_case_expr, + when_then: new_when_then, + else_expr: new_else, + }) + } + LogicalExpr::Alias { expr, alias } => { + let new_expr = 
Self::replace_aggregates_with_columns_dedup( + *expr, + new_aggregates, + aggregate_map, + counter, + )?; + Ok(LogicalExpr::Alias { + expr: Box::new(new_expr), + alias, + }) + } + // Other expressions - keep as is + _ => Ok(expr), + } + } + + // Get column name from expression + fn expr_to_column_name(expr: &ast::Expr) -> String { + match expr { + ast::Expr::Id(name) => Self::name_to_string(name), + ast::Expr::Qualified(_, col) => Self::name_to_string(col), + ast::Expr::FunctionCall { name, .. } => Self::name_to_string(name), + ast::Expr::FunctionCallStar { name, .. } => { + format!("{}(*)", Self::name_to_string(name)) + } + _ => "expr".to_string(), + } + } + + // Get table schema + fn get_table_schema(&self, table_name: &str) -> Result { + // Look up table in schema + let table = self + .schema + .get_table(table_name) + .ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?; + + let mut columns = Vec::new(); + for col in table.columns() { + if let Some(ref name) = col.name { + columns.push((name.clone(), col.ty)); + } + } + + Ok(Arc::new(LogicalSchema::new(columns))) + } + + // Infer expression type + fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result { + match expr { + LogicalExpr::Column(col) => { + if let Some((_, typ)) = schema.find_column(&col.name) { + Ok(*typ) + } else { + Ok(Type::Text) + } + } + LogicalExpr::Literal(Value::Integer(_)) => Ok(Type::Integer), + LogicalExpr::Literal(Value::Float(_)) => Ok(Type::Real), + LogicalExpr::Literal(Value::Text(_)) => Ok(Type::Text), + LogicalExpr::Literal(Value::Null) => Ok(Type::Null), + LogicalExpr::Literal(Value::Blob(_)) => Ok(Type::Blob), + LogicalExpr::BinaryExpr { op, left, right } => { + match op { + ast::Operator::Add | ast::Operator::Subtract | ast::Operator::Multiply => { + // Infer types of operands to match SQLite/Numeric behavior + let left_type = Self::infer_expr_type(left, schema)?; + let right_type = Self::infer_expr_type(right, schema)?; + + // Integer op Integer = Integer (matching core/numeric/mod.rs behavior) + // Any operation with Real = Real + match (left_type, right_type) { + (Type::Integer, Type::Integer) => Ok(Type::Integer), + (Type::Integer, Type::Real) + | (Type::Real, Type::Integer) + | (Type::Real, Type::Real) => Ok(Type::Real), + (Type::Null, _) | (_, Type::Null) => Ok(Type::Null), + // For Text/Blob, SQLite coerces to numeric, defaulting to Real + _ => Ok(Type::Real), + } + } + ast::Operator::Divide => { + // Division always produces Real in SQLite + Ok(Type::Real) + } + ast::Operator::Modulus => { + // Modulus follows same rules as other arithmetic ops + let left_type = Self::infer_expr_type(left, schema)?; + let right_type = Self::infer_expr_type(right, schema)?; + match (left_type, right_type) { + (Type::Integer, Type::Integer) => Ok(Type::Integer), + _ => Ok(Type::Real), + } + } + ast::Operator::Equals + | ast::Operator::NotEquals + | ast::Operator::Less + | ast::Operator::LessEquals + | ast::Operator::Greater + | ast::Operator::GreaterEquals + | ast::Operator::And + | ast::Operator::Or + | ast::Operator::Is + | ast::Operator::IsNot => Ok(Type::Integer), + ast::Operator::Concat => Ok(Type::Text), + _ => Ok(Type::Text), // Default for other operators + } + } + LogicalExpr::UnaryExpr { op, expr } => match op { + ast::UnaryOperator::Not => Ok(Type::Integer), + ast::UnaryOperator::Negative | ast::UnaryOperator::Positive => { + Self::infer_expr_type(expr, schema) + } + ast::UnaryOperator::BitwiseNot => Ok(Type::Integer), + }, + LogicalExpr::AggregateFunction { 
fun, .. } => match fun { + AggFunc::Count | AggFunc::Count0 => Ok(Type::Integer), + AggFunc::Sum | AggFunc::Avg | AggFunc::Total => Ok(Type::Real), + AggFunc::Min | AggFunc::Max => Ok(Type::Text), + AggFunc::GroupConcat | AggFunc::StringAgg => Ok(Type::Text), + #[cfg(feature = "json")] + AggFunc::JsonbGroupArray + | AggFunc::JsonGroupArray + | AggFunc::JsonbGroupObject + | AggFunc::JsonGroupObject => Ok(Type::Text), + AggFunc::External(_) => Ok(Type::Text), // Default for external + }, + LogicalExpr::Alias { expr, .. } => Self::infer_expr_type(expr, schema), + LogicalExpr::IsNull { .. } => Ok(Type::Integer), + LogicalExpr::InList { .. } | LogicalExpr::InSubquery { .. } => Ok(Type::Integer), + LogicalExpr::Exists { .. } => Ok(Type::Integer), + LogicalExpr::Between { .. } => Ok(Type::Integer), + LogicalExpr::Like { .. } => Ok(Type::Integer), + _ => Ok(Type::Text), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type}; + use turso_parser::parser::Parser; + + fn create_test_schema() -> Schema { + let mut schema = Schema::new(false); + + // Create users table + let users_table = BTreeTable { + name: "users".to_string(), + root_page: 2, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("age".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("email".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + schema.add_btree_table(Arc::new(users_table)); + + // Create orders table + let orders_table = BTreeTable { + name: "orders".to_string(), + root_page: 3, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("user_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("amount".to_string()), + ty: Type::Real, + ty_str: "REAL".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + 
collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + schema.add_btree_table(Arc::new(orders_table)); + + schema + } + + fn parse_and_build(sql: &str, schema: &Schema) -> Result { + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser + .next() + .ok_or_else(|| LimboError::ParseError("Empty statement".to_string()))? + .map_err(|e| LimboError::ParseError(e.to_string()))?; + match cmd { + ast::Cmd::Stmt(stmt) => { + let mut builder = LogicalPlanBuilder::new(schema); + builder.build_statement(&stmt) + } + _ => Err(LimboError::ParseError( + "Only SQL statements are supported".to_string(), + )), + } + } + + #[test] + fn test_simple_select() { + let schema = create_test_schema(); + let sql = "SELECT id, name FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 2); + assert!(matches!(proj.exprs[0], LogicalExpr::Column(_))); + assert!(matches!(proj.exprs[1], LogicalExpr::Column(_))); + + match &*proj.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_select_with_filter() { + let schema = create_test_schema(); + let sql = "SELECT name FROM users WHERE age > 18"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + + match &*proj.input { + LogicalPlan::Filter(filter) => { + assert!(matches!( + filter.predicate, + LogicalExpr::BinaryExpr { + op: ast::Operator::Greater, + .. + } + )); + + match &*filter.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan"), + } + } + _ => panic!("Expected Filter"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_aggregate_with_group_by() { + let schema = create_test_schema(); + let sql = "SELECT user_id, SUM(amount) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + assert_eq!(agg.aggr_expr.len(), 1); + assert_eq!(agg.schema.column_count(), 2); + + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Sum, + .. + } + )); + + match &*agg.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "orders"); + } + _ => panic!("Expected TableScan"), + } + } + _ => panic!("Expected Aggregate (no projection)"), + } + } + + #[test] + fn test_aggregate_without_group_by() { + let schema = create_test_schema(); + let sql = "SELECT COUNT(*), MAX(age) FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 0); + assert_eq!(agg.aggr_expr.len(), 2); + assert_eq!(agg.schema.column_count(), 2); + + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Count, + .. + } + )); + + assert!(matches!( + agg.aggr_expr[1], + LogicalExpr::AggregateFunction { + fun: AggFunc::Max, + .. 
+ } + )); + } + _ => panic!("Expected Aggregate (no projection)"), + } + } + + #[test] + fn test_order_by() { + let schema = create_test_schema(); + let sql = "SELECT name FROM users ORDER BY age DESC, name ASC"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 2); + assert!(!sort.exprs[0].asc); // DESC + assert!(sort.exprs[1].asc); // ASC + + match &*sort.input { + LogicalPlan::Projection(_) => {} + _ => panic!("Expected Projection"), + } + } + _ => panic!("Expected Sort"), + } + } + + #[test] + fn test_limit_offset() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users LIMIT 10 OFFSET 5"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Limit(limit) => { + assert_eq!(limit.fetch, Some(10)); + assert_eq!(limit.skip, Some(5)); + } + _ => panic!("Expected Limit"), + } + } + + #[test] + fn test_order_by_with_limit() { + let schema = create_test_schema(); + let sql = "SELECT name FROM users ORDER BY age DESC LIMIT 5"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // Should produce: Limit -> Sort -> Projection -> TableScan + match plan { + LogicalPlan::Limit(limit) => { + assert_eq!(limit.fetch, Some(5)); + assert_eq!(limit.skip, None); + + match &*limit.input { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 1); + assert!(!sort.exprs[0].asc); // DESC + + match &*sort.input { + LogicalPlan::Projection(_) => {} + _ => panic!("Expected Projection under Sort"), + } + } + _ => panic!("Expected Sort under Limit"), + } + } + _ => panic!("Expected Limit at top level"), + } + } + + #[test] + fn test_distinct() { + let schema = create_test_schema(); + let sql = "SELECT DISTINCT name FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Distinct(distinct) => match &*distinct.input { + LogicalPlan::Projection(_) => {} + _ => panic!("Expected Projection"), + }, + _ => panic!("Expected Distinct"), + } + } + + #[test] + fn test_union() { + let schema = create_test_schema(); + let sql = "SELECT id FROM users UNION SELECT user_id FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Union(union) => { + assert!(!union.all); + assert_eq!(union.inputs.len(), 2); + } + _ => panic!("Expected Union"), + } + } + + #[test] + fn test_union_all() { + let schema = create_test_schema(); + let sql = "SELECT id FROM users UNION ALL SELECT user_id FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Union(union) => { + assert!(union.all); + assert_eq!(union.inputs.len(), 2); + } + _ => panic!("Expected Union"), + } + } + + #[test] + fn test_union_with_order_by() { + let schema = create_test_schema(); + let sql = "SELECT id, name FROM users UNION SELECT user_id, name FROM orders ORDER BY id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 1); + assert!(sort.exprs[0].asc); // Default ASC + + match &*sort.input { + LogicalPlan::Union(union) => { + assert!(!union.all); // UNION (not UNION ALL) + assert_eq!(union.inputs.len(), 2); + } + _ => panic!("Expected Union under Sort"), + } + } + _ => panic!("Expected Sort at top level"), + } + } + + #[test] + fn test_with_cte() { + let schema = create_test_schema(); + let sql = "WITH active_users AS (SELECT * FROM users WHERE age > 18) SELECT name FROM active_users"; + let plan = parse_and_build(sql, 
&schema).unwrap(); + + match plan { + LogicalPlan::WithCTE(with) => { + assert_eq!(with.ctes.len(), 1); + assert!(with.ctes.contains_key("active_users")); + + let cte = &with.ctes["active_users"]; + match &**cte { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(_) => {} + _ => panic!("Expected Filter in CTE"), + }, + _ => panic!("Expected Projection in CTE"), + } + + match &*with.body { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::CTERef(cte_ref) => { + assert_eq!(cte_ref.name, "active_users"); + } + _ => panic!("Expected CTERef"), + }, + _ => panic!("Expected Projection in body"), + } + } + _ => panic!("Expected WithCTE"), + } + } + + #[test] + fn test_case_expression() { + let schema = create_test_schema(); + let sql = "SELECT CASE WHEN age < 18 THEN 'minor' WHEN age < 65 THEN 'adult' ELSE 'senior' END FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + assert!(matches!(proj.exprs[0], LogicalExpr::Case { .. })); + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_in_list() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE id IN (1, 2, 3)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::InList { list, negated, .. } => { + assert!(!negated); + assert_eq!(list.len(), 3); + } + _ => panic!("Expected InList"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_in_subquery() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => { + assert!(matches!(filter.predicate, LogicalExpr::InSubquery { .. })); + } + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_exists_subquery() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => { + assert!(matches!(filter.predicate, LogicalExpr::Exists { .. })); + } + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_between() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE age BETWEEN 18 AND 65"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::Between { negated, .. } => { + assert!(!negated); + } + _ => panic!("Expected Between"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_like() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE name LIKE 'John%'"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::Like { + negated, escape, .. 
+ } => { + assert!(!negated); + assert!(escape.is_none()); + } + _ => panic!("Expected Like"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_is_null() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE email IS NULL"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::IsNull { negated, .. } => { + assert!(!negated); + } + _ => panic!("Expected IsNull, got: {:?}", filter.predicate), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_is_not_null() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE email IS NOT NULL"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::IsNull { negated, .. } => { + assert!(negated); + } + _ => panic!("Expected IsNull"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_values_clause() { + let schema = create_test_schema(); + let sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b'))"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Values(values) => { + assert_eq!(values.rows.len(), 2); + assert_eq!(values.rows[0].len(), 2); + } + _ => panic!("Expected Values"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_complex_expression_with_aggregation() { + // Test: SELECT sum(id + 2) * 2 FROM orders GROUP BY user_id + let schema = create_test_schema(); + + // Test the complex case: sum((id + 2)) * 2 with parentheses + let sql = "SELECT sum((id + 2)) * 2 FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::BinaryExpr { left, op, right } => { + assert_eq!(*op, BinaryOperator::Multiply); + assert!(matches!(**left, LogicalExpr::Column(_))); + assert!(matches!(**right, LogicalExpr::Literal(_))); + } + _ => panic!("Expected BinaryExpr in projection"), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + + assert_eq!(agg.aggr_expr.len(), 1); + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { fun, args, .. } => { + assert_eq!(*fun, AggregateFunction::Sum); + assert_eq!(args.len(), 1); + match &args[0] { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_arg_proj_")); + } + _ => panic!("Expected Column reference to projected expression in aggregate args, got {:?}", args[0]), + } + } + _ => panic!("Expected AggregateFunction"), + } + + match &*agg.input { + LogicalPlan::Projection(inner_proj) => { + assert!(inner_proj.exprs.len() >= 2); + let has_binary_add = inner_proj.exprs.iter().any(|e| { + matches!( + e, + LogicalExpr::BinaryExpr { + op: BinaryOperator::Add, + .. 
+ } + ) + }); + assert!( + has_binary_add, + "Should have id + 2 expression in inner projection" + ); + } + _ => panic!("Expected Projection as input to Aggregate"), + } + } + _ => panic!("Expected Aggregate under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_function_on_aggregate_result() { + let schema = create_test_schema(); + + let sql = "SELECT abs(sum(id)) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun, "abs"); + assert_eq!(args.len(), 1); + assert!(matches!(args[0], LogicalExpr::Column(_))); + } + _ => panic!("Expected ScalarFunction in projection"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_multiple_aggregates_with_arithmetic() { + let schema = create_test_schema(); + + let sql = "SELECT sum(id) * 2 + count(*) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::BinaryExpr { op, .. } => { + assert_eq!(*op, BinaryOperator::Add); + } + _ => panic!("Expected BinaryExpr"), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 2); + } + _ => panic!("Expected Aggregate"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_projection_aggregation_projection() { + let schema = create_test_schema(); + + // This tests: projection -> aggregation -> projection + // The inner projection computes (id + 2), then we aggregate sum(), then apply abs() + let sql = "SELECT abs(sum(id + 2)) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // Should produce: Projection(abs) -> Aggregate(sum) -> Projection(id + 2) -> TableScan + match plan { + LogicalPlan::Projection(outer_proj) => { + assert_eq!(outer_proj.exprs.len(), 1); + + // Outer projection should apply abs() function + match &outer_proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun, "abs"); + assert_eq!(args.len(), 1); + assert!(matches!(args[0], LogicalExpr::Column(_))); + } + _ => panic!("Expected abs() function in outer projection"), + } + + // Next should be the Aggregate + match &*outer_proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + assert_eq!(agg.aggr_expr.len(), 1); + + // The aggregate should be summing a column reference + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { fun, args, .. } => { + assert_eq!(*fun, AggregateFunction::Sum); + assert_eq!(args.len(), 1); + + // Should reference the projected column + match &args[0] { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_arg_proj_")); + } + _ => panic!("Expected column reference in aggregate"), + } + } + _ => panic!("Expected AggregateFunction"), + } + + // Input to aggregate should be a projection computing id + 2 + match &*agg.input { + LogicalPlan::Projection(inner_proj) => { + // Should have at least the group column and the computed expression + assert!(inner_proj.exprs.len() >= 2); + + // Check for the id + 2 expression + let has_add_expr = inner_proj.exprs.iter().any(|e| { + matches!( + e, + LogicalExpr::BinaryExpr { + op: BinaryOperator::Add, + .. 
+ } + ) + }); + assert!( + has_add_expr, + "Should have id + 2 expression in inner projection" + ); + } + _ => panic!("Expected inner Projection under Aggregate"), + } + } + _ => panic!("Expected Aggregate under outer Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_group_by_validation_allow_grouped_column() { + let schema = create_test_schema(); + + // Test that grouped columns are allowed + let sql = "SELECT user_id, COUNT(*) FROM orders GROUP BY user_id"; + let result = parse_and_build(sql, &schema); + + assert!(result.is_ok(), "Should allow grouped column in SELECT"); + } + + #[test] + fn test_group_by_validation_allow_constants() { + let schema = create_test_schema(); + + // Test that simple constants are allowed even when not grouped + let sql = "SELECT user_id, 42, COUNT(*) FROM orders GROUP BY user_id"; + let result = parse_and_build(sql, &schema); + + assert!( + result.is_ok(), + "Should allow simple constants in SELECT with GROUP BY" + ); + + let sql_complex = "SELECT user_id, (100 + 50) * 2, COUNT(*) FROM orders GROUP BY user_id"; + let result_complex = parse_and_build(sql_complex, &schema); + + assert!( + result_complex.is_ok(), + "Should allow complex constant expressions in SELECT with GROUP BY" + ); + } + + #[test] + fn test_parenthesized_aggregate_expressions() { + let schema = create_test_schema(); + + let sql = "SELECT 25, (MAX(id) / 3), 39 FROM orders"; + let result = parse_and_build(sql, &schema); + + assert!( + result.is_ok(), + "Should handle parenthesized aggregate expressions" + ); + + let plan = result.unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 3); + + assert!(matches!( + proj.exprs[0], + LogicalExpr::Literal(Value::Integer(25)) + )); + + match &proj.exprs[1] { + LogicalExpr::BinaryExpr { left, op, right } => { + assert_eq!(*op, BinaryOperator::Divide); + assert!(matches!(&**left, LogicalExpr::Column(_))); + assert!(matches!(&**right, LogicalExpr::Literal(Value::Integer(3)))); + } + _ => panic!("Expected BinaryExpr for (MAX(id) / 3)"), + } + + assert!(matches!( + proj.exprs[2], + LogicalExpr::Literal(Value::Integer(39)) + )); + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 1); + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Max, + .. 
+ } + )); + } + _ => panic!("Expected Aggregate node under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_duplicate_aggregate_reuse() { + let schema = create_test_schema(); + + let sql = "SELECT (COUNT(*) - 225), 30, COUNT(*) FROM orders"; + let result = parse_and_build(sql, &schema); + + assert!(result.is_ok(), "Should handle duplicate aggregates"); + + let plan = result.unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 3); + + match &proj.exprs[0] { + LogicalExpr::BinaryExpr { left, op, right } => { + assert_eq!(*op, BinaryOperator::Subtract); + match &**left { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)"); + } + _ => panic!("Expected Column reference for COUNT(*)"), + } + assert!(matches!( + &**right, + LogicalExpr::Literal(Value::Integer(225)) + )); + } + _ => panic!("Expected BinaryExpr for (COUNT(*) - 225)"), + } + + assert!(matches!( + proj.exprs[1], + LogicalExpr::Literal(Value::Integer(30)) + )); + + match &proj.exprs[2] { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)"); + } + _ => panic!("Expected Column reference for COUNT(*)"), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!( + agg.aggr_expr.len(), + 1, + "Should have only one COUNT(*) aggregate" + ); + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Count, + .. + } + )); + } + _ => panic!("Expected Aggregate node under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_aggregate_without_group_by_allow_constants() { + let schema = create_test_schema(); + + // Test that constants are allowed with aggregates even without GROUP BY + let sql = "SELECT 42, COUNT(*), MAX(amount) FROM orders"; + let result = parse_and_build(sql, &schema); + + assert!( + result.is_ok(), + "Should allow simple constants with aggregates without GROUP BY" + ); + + // Test complex constant expressions + let sql_complex = "SELECT (9 / 6) % 5, COUNT(*), MAX(amount) FROM orders"; + let result_complex = parse_and_build(sql_complex, &schema); + + assert!( + result_complex.is_ok(), + "Should allow complex constant expressions with aggregates without GROUP BY" + ); + } + + #[test] + fn test_aggregate_without_group_by_creates_aggregate_node() { + let schema = create_test_schema(); + + // Test that aggregate without GROUP BY creates proper Aggregate node + let sql = "SELECT MAX(amount) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // Should be: Aggregate -> TableScan (no projection needed for simple aggregate) + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 0, "Should have no group expressions"); + assert_eq!( + agg.aggr_expr.len(), + 1, + "Should have one aggregate expression" + ); + assert_eq!( + agg.schema.column_count(), + 1, + "Schema should have one column" + ); + } + _ => panic!("Expected Aggregate at top level (no projection)"), + } + } + + #[test] + fn test_scalar_vs_aggregate_function_classification() { + let schema = create_test_schema(); + + // Test MIN/MAX with 1 argument - should be aggregate + let sql = "SELECT MIN(amount) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 1, "MIN(x) should be an aggregate"); + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction 
{ fun, args, .. } => { + assert!(matches!(fun, AggFunc::Min)); + assert_eq!(args.len(), 1); + } + _ => panic!("Expected AggregateFunction"), + } + } + _ => panic!("Expected Aggregate node for MIN(x)"), + } + + // Test MIN/MAX with 2 arguments - should be scalar in projection + let sql = "SELECT MIN(amount, user_id) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1, "Should have one projection expression"); + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!( + fun.to_lowercase(), + "min", + "MIN(x,y) should be a scalar function" + ); + assert_eq!(args.len(), 2); + } + _ => panic!("Expected ScalarFunction for MIN(x,y)"), + } + } + _ => panic!("Expected Projection node for scalar MIN(x,y)"), + } + + // Test MAX with 3 arguments - should be scalar + let sql = "SELECT MAX(amount, user_id, id) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!( + fun.to_lowercase(), + "max", + "MAX(x,y,z) should be a scalar function" + ); + assert_eq!(args.len(), 3); + } + _ => panic!("Expected ScalarFunction for MAX(x,y,z)"), + } + } + _ => panic!("Expected Projection node for scalar MAX(x,y,z)"), + } + + // Test that MIN with 0 args is treated as scalar (will fail later in execution) + let sql = "SELECT MIN() FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun.to_lowercase(), "min"); + assert_eq!(args.len(), 0, "MIN() should be scalar with 0 args"); + } + _ => panic!("Expected ScalarFunction for MIN()"), + }, + _ => panic!("Expected Projection for MIN()"), + } + + // Test other functions that are always aggregate (COUNT, SUM, AVG) + let sql = "SELECT COUNT(*), SUM(amount), AVG(amount) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 3, "Should have 3 aggregate functions"); + for expr in &agg.aggr_expr { + assert!(matches!(expr, LogicalExpr::AggregateFunction { .. })); + } + } + _ => panic!("Expected Aggregate node"), + } + + // Test scalar functions that are never aggregates (ABS, ROUND, etc.) + let sql = "SELECT ABS(amount), ROUND(amount), LENGTH(product) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 3, "Should have 3 scalar functions"); + for expr in &proj.exprs { + match expr { + LogicalExpr::ScalarFunction { .. 
} => {} + _ => panic!("Expected all ScalarFunctions"), + } + } + } + _ => panic!("Expected Projection node for scalar functions"), + } + } + + #[test] + fn test_mixed_aggregate_and_group_columns() { + let schema = create_test_schema(); + + // When selecting both aggregate and grouping columns + let sql = "SELECT user_id, sum(id) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // No projection needed - aggregate outputs exactly what we select + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + assert_eq!(agg.aggr_expr.len(), 1); + assert_eq!(agg.schema.column_count(), 2); + } + _ => panic!("Expected Aggregate (no projection)"), + } + } + + #[test] + fn test_scalar_function_wrapping_aggregate_no_group_by() { + // Test: SELECT HEX(SUM(age + 2)) FROM users + // Expected structure: + // Projection { exprs: [ScalarFunction(HEX, [Column])] } + // -> Aggregate { aggr_expr: [Sum(BinaryExpr(age + 2))], group_expr: [] } + // -> Projection { exprs: [BinaryExpr(age + 2)] } + // -> TableScan("users") + + let schema = create_test_schema(); + let sql = "SELECT HEX(SUM(age + 2)) FROM users"; + let mut parser = Parser::new(sql.as_bytes()); + let stmt = parser.next().unwrap().unwrap(); + + let plan = match stmt { + ast::Cmd::Stmt(stmt) => { + let mut builder = LogicalPlanBuilder::new(&schema); + builder.build_statement(&stmt).unwrap() + } + _ => panic!("Expected SQL statement"), + }; + + match &plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1, "Should have one expression"); + + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun, "HEX", "Outer function should be HEX"); + assert_eq!(args.len(), 1, "HEX should have one argument"); + + match &args[0] { + LogicalExpr::Column(_) => {} + LogicalExpr::AggregateFunction { .. } => { + panic!("Aggregate function should not be embedded in projection! 
It should be in a separate Aggregate operator"); + } + _ => panic!( + "Expected column reference as argument to HEX, got: {:?}", + args[0] + ), + } + } + _ => panic!("Expected ScalarFunction (HEX), got: {:?}", proj.exprs[0]), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 0, "Should have no GROUP BY"); + assert_eq!( + agg.aggr_expr.len(), + 1, + "Should have one aggregate expression" + ); + + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { + fun, + args, + distinct, + } => { + assert_eq!(*fun, crate::function::AggFunc::Sum, "Should be SUM"); + assert!(!distinct, "Should not be DISTINCT"); + assert_eq!(args.len(), 1, "SUM should have one argument"); + + match &args[0] { + LogicalExpr::Column(col) => { + // When aggregate arguments are complex, they get pre-projected + assert!(col.name.starts_with("__agg_arg_proj_"), + "Should reference pre-projected column, got: {}", col.name); + } + LogicalExpr::BinaryExpr { left, op, right } => { + // Simple case without pre-projection (shouldn't happen with current implementation) + assert_eq!(*op, ast::Operator::Add, "Should be addition"); + + match (&**left, &**right) { + (LogicalExpr::Column(col), LogicalExpr::Literal(val)) => { + assert_eq!(col.name, "age", "Should reference age column"); + assert_eq!(*val, Value::Integer(2), "Should add 2"); + } + _ => panic!("Expected age + 2"), + } + } + _ => panic!("Expected Column reference or BinaryExpr for aggregate argument, got: {:?}", args[0]), + } + } + _ => panic!("Expected AggregateFunction"), + } + + match &*agg.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan under projection"), + }, + _ => panic!("Expected TableScan or Projection under Aggregate"), + } + } + _ => panic!( + "Expected Aggregate operator under Projection, got: {:?}", + proj.input + ), + } + } + _ => panic!("Expected Projection as top-level operator, got: {plan:?}"), + } + } +} diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 7d29a8173..2de196a8f 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -21,6 +21,8 @@ pub(crate) mod group_by; pub(crate) mod index; pub(crate) mod insert; pub(crate) mod integrity_check; +#[cfg(test)] +pub(crate) mod logical; pub(crate) mod main_loop; pub(crate) mod optimizer; pub(crate) mod order_by; From dbe29e4bab4bf8493ff302752727270567b5bbf6 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Mon, 25 Aug 2025 21:00:30 -0500 Subject: [PATCH 2/9] fix aggregator operator It needs to keep track of the old values to emit retractions (when the aggregation changes, remove old value, insert new) --- core/incremental/operator.rs | 281 +++++++++++++++++++++++++++++++++-- 1 file changed, 268 insertions(+), 13 deletions(-) diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 0391e3c0a..38e1689da 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -1472,8 +1472,9 @@ impl AggregateOperator { pub fn process_delta(&mut self, delta: Delta) -> Delta { let mut output_delta = Delta::new(); - // Track which groups were modified + // Track which groups were modified and their old values let mut modified_groups = HashSet::new(); + let mut old_values: HashMap> = HashMap::new(); // Process each change in the delta for (row, weight) in &delta.changes { @@ -1484,6 +1485,17 @@ impl 
AggregateOperator { // Extract group key let group_key = self.extract_group_key(&row.values); let group_key_str = Self::group_key_to_string(&group_key); + + // Store old aggregate values BEFORE applying the delta + // (only for the first time we see this group in this batch) + if !modified_groups.contains(&group_key_str) { + if let Some(state) = self.group_states.get(&group_key_str) { + let mut old_row = group_key.clone(); + old_row.extend(state.to_values(&self.aggregates)); + old_values.insert(group_key_str.clone(), old_row); + } + } + modified_groups.insert(group_key_str.clone()); // Store the actual group key values @@ -1514,17 +1526,25 @@ impl AggregateOperator { .cloned() .unwrap_or_default(); + // Generate a unique key for this group + // We use a hash of the group key to ensure consistency + let result_key = group_key_str + .bytes() + .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64)); + + // Emit retraction for old value if it existed + if let Some(old_row_values) = old_values.get(&group_key_str) { + let old_row = HashableRow::new(result_key, old_row_values.clone()); + output_delta.changes.push((old_row.clone(), -1)); + // Also remove from current state + self.current_state.changes.push((old_row, -1)); + } + if let Some(state) = self.group_states.get(&group_key_str) { // Build output row: group_by columns + aggregate values let mut output_values = group_key.clone(); output_values.extend(state.to_values(&self.aggregates)); - // Generate a unique key for this group - // We use a hash of the group key to ensure consistency - let result_key = group_key_str - .bytes() - .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64)); - // Check if group should be removed (count is 0) if state.count > 0 { // Add to output delta with positive weight @@ -1534,12 +1554,8 @@ impl AggregateOperator { // Update current state self.current_state.changes.push((output_row, 1)); } else { - // Add to output delta with negative weight (deletion) - let output_row = HashableRow::new(result_key, output_values); - output_delta.changes.push((output_row.clone(), -1)); - - // Mark for removal in current state - self.current_state.changes.push((output_row, -1)); + // Group has count=0, remove from state + // (we already emitted the retraction above if needed) self.group_states.remove(&group_key_str); self.group_key_values.remove(&group_key_str); } @@ -1603,6 +1619,245 @@ mod tests { ); } + // Aggregate tests + #[test] + fn test_aggregate_incremental_update_emits_retraction() { + // This test verifies that when an aggregate value changes, + // the operator emits both a retraction (-1) of the old value + // and an insertion (+1) of the new value. 
+ + // Create an aggregate operator for SUM(age) with no GROUP BY + let mut agg = AggregateOperator::new( + vec![], // No GROUP BY + vec![AggregateFunction::Sum("age".to_string())], + vec!["id".to_string(), "name".to_string(), "age".to_string()], + ); + + // Initial data: 3 users + let mut initial_delta = Delta::new(); + initial_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".to_string().into()), + Value::Integer(25), + ], + ); + initial_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".to_string().into()), + Value::Integer(30), + ], + ); + initial_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".to_string().into()), + Value::Integer(35), + ], + ); + + // Initialize with initial data + agg.initialize(initial_delta); + + // Verify initial state: SUM(age) = 25 + 30 + 35 = 90 + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 1, "Should have one aggregate row"); + let (row, weight) = &state.changes[0]; + assert_eq!(*weight, 1, "Aggregate row should have weight 1"); + assert_eq!(row.values[0], Value::Float(90.0), "SUM should be 90"); + + // Now add a new user (incremental update) + let mut update_delta = Delta::new(); + update_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("David".to_string().into()), + Value::Integer(40), + ], + ); + + // Process the incremental update + let output_delta = agg.process_delta(update_delta); + + // CRITICAL: The output delta should contain TWO changes: + // 1. Retraction of old aggregate value (90) with weight -1 + // 2. Insertion of new aggregate value (130) with weight +1 + assert_eq!( + output_delta.changes.len(), + 2, + "Expected 2 changes (retraction + insertion), got {}: {:?}", + output_delta.changes.len(), + output_delta.changes + ); + + // Verify the retraction comes first + let (retraction_row, retraction_weight) = &output_delta.changes[0]; + assert_eq!( + *retraction_weight, -1, + "First change should be a retraction" + ); + assert_eq!( + retraction_row.values[0], + Value::Float(90.0), + "Retracted value should be the old sum (90)" + ); + + // Verify the insertion comes second + let (insertion_row, insertion_weight) = &output_delta.changes[1]; + assert_eq!(*insertion_weight, 1, "Second change should be an insertion"); + assert_eq!( + insertion_row.values[0], + Value::Float(130.0), + "Inserted value should be the new sum (130)" + ); + + // Both changes should have the same row ID (since it's the same aggregate group) + assert_eq!( + retraction_row.rowid, insertion_row.rowid, + "Retraction and insertion should have the same row ID" + ); + } + + #[test] + fn test_aggregate_with_group_by_emits_retractions() { + // This test verifies that when aggregate values change for grouped data, + // the operator emits both retractions and insertions correctly for each group. 
+ + // Create an aggregate operator for SUM(score) GROUP BY team + let mut agg = AggregateOperator::new( + vec!["team".to_string()], // GROUP BY team + vec![AggregateFunction::Sum("score".to_string())], + vec![ + "id".to_string(), + "team".to_string(), + "player".to_string(), + "score".to_string(), + ], + ); + + // Initial data: players on different teams + let mut initial_delta = Delta::new(); + initial_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("red".to_string().into()), + Value::Text("Alice".to_string().into()), + Value::Integer(10), + ], + ); + initial_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("blue".to_string().into()), + Value::Text("Bob".to_string().into()), + Value::Integer(15), + ], + ); + initial_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("red".to_string().into()), + Value::Text("Charlie".to_string().into()), + Value::Integer(20), + ], + ); + + // Initialize with initial data + agg.initialize(initial_delta); + + // Verify initial state: red team = 30, blue team = 15 + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 2, "Should have two groups"); + + // Find the red and blue team aggregates + let mut red_sum = None; + let mut blue_sum = None; + for (row, weight) in &state.changes { + assert_eq!(*weight, 1); + if let Value::Text(team) = &row.values[0] { + if team.as_str() == "red" { + red_sum = Some(&row.values[1]); + } else if team.as_str() == "blue" { + blue_sum = Some(&row.values[1]); + } + } + } + assert_eq!( + red_sum, + Some(&Value::Float(30.0)), + "Red team sum should be 30" + ); + assert_eq!( + blue_sum, + Some(&Value::Float(15.0)), + "Blue team sum should be 15" + ); + + // Now add a new player to the red team (incremental update) + let mut update_delta = Delta::new(); + update_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("red".to_string().into()), + Value::Text("David".to_string().into()), + Value::Integer(25), + ], + ); + + // Process the incremental update + let output_delta = agg.process_delta(update_delta); + + // Should have 2 changes: retraction of old red team sum, insertion of new red team sum + // Blue team should NOT be affected + assert_eq!( + output_delta.changes.len(), + 2, + "Expected 2 changes for red team only, got {}: {:?}", + output_delta.changes.len(), + output_delta.changes + ); + + // Both changes should be for the red team + let mut found_retraction = false; + let mut found_insertion = false; + + for (row, weight) in &output_delta.changes { + if let Value::Text(team) = &row.values[0] { + assert_eq!(team.as_str(), "red", "Only red team should have changes"); + + if *weight == -1 { + // Retraction of old value + assert_eq!( + row.values[1], + Value::Float(30.0), + "Should retract old sum of 30" + ); + found_retraction = true; + } else if *weight == 1 { + // Insertion of new value + assert_eq!( + row.values[1], + Value::Float(55.0), + "Should insert new sum of 55" + ); + found_insertion = true; + } + } + } + + assert!(found_retraction, "Should have found retraction"); + assert!(found_insertion, "Should have found insertion"); + } + // Join tests #[test] fn test_join_uses_delta_formula() { From 6e2bd364ee05c48bb979e65fa9ea8b284762cbd0 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Tue, 26 Aug 2025 12:11:31 -0500 Subject: [PATCH 3/9] fix issue with rowids and deletions The operator itself should handle deletions and updates that change the rowid by consolidating its state. 
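For example, an UPDATE that changes an INTEGER PRIMARY KEY, like

    UPDATE t SET a = 1 WHERE a = 3;

reaches the operator as a delete of the old rowid plus an insert of the
new one. In Delta terms (this is exactly what the new test below builds):

    let mut update_delta = Delta::new();
    update_delta.delete(3, vec![Value::Integer(3), Value::Integer(3)]);
    update_delta.insert(1, vec![Value::Integer(1), Value::Integer(3)]);

Without consolidation, reading the state back would still show both
entries; consolidating sums the weights per row, so only the rowid=1
row with positive weight survives.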
Our current materialized views track state themselves, so we don't see this problem now. But it becomes apparent once we switch the views to use circuits. --- core/incremental/operator.rs | 72 ++++++++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 38e1689da..34e2e2f1b 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -575,7 +575,10 @@ impl IncrementalOperator for FilterOperator { } fn get_current_state(&self) -> Delta { - self.current_state.clone() + // Return a consolidated view of the current state + let mut consolidated = self.current_state.clone(); + consolidated.consolidate(); + consolidated } fn set_tracker(&mut self, tracker: Arc>) { @@ -918,7 +921,10 @@ impl IncrementalOperator for ProjectOperator { } fn get_current_state(&self) -> Delta { - self.current_state.clone() + // Return a consolidated view of the current state + let mut consolidated = self.current_state.clone(); + consolidated.consolidate(); + consolidated } fn set_tracker(&mut self, tracker: Arc>) { @@ -1584,7 +1590,10 @@ impl IncrementalOperator for AggregateOperator { } fn get_current_state(&self) -> Delta { - self.current_state.clone() + // Return a consolidated view of the current state + let mut consolidated = self.current_state.clone(); + consolidated.consolidate(); + consolidated } fn set_tracker(&mut self, tracker: Arc>) { @@ -2390,4 +2399,61 @@ mod tests { assert_eq!(change_row.values[1], Value::Integer(1)); // count: 2 - 1 assert_eq!(change_row.values[2], Value::Integer(200)); // sum: 300 - 100 } + + #[test] + fn test_filter_operator_rowid_update() { + // When a row's rowid changes (e.g., UPDATE t SET a=1 WHERE a=3 on INTEGER PRIMARY KEY), + // the operator should properly consolidate the state + + let mut filter = FilterOperator::new( + FilterPredicate::GreaterThan { + column: "b".to_string(), + value: Value::Integer(2), + }, + vec!["a".to_string(), "b".to_string()], + ); + + // Initialize with a row (rowid=3, values=[3, 3]) + let mut init_data = Delta::new(); + init_data.insert(3, vec![Value::Integer(3), Value::Integer(3)]); + filter.initialize(init_data); + + // Check initial state + let state = filter.get_current_state(); + assert_eq!(state.changes.len(), 1); + assert_eq!(state.changes[0].0.rowid, 3); + assert_eq!( + state.changes[0].0.values, + vec![Value::Integer(3), Value::Integer(3)] + ); + + // Simulate an UPDATE that changes rowid from 3 to 1 + // This is sent as: delete(3) + insert(1) + let mut update_delta = Delta::new(); + update_delta.delete(3, vec![Value::Integer(3), Value::Integer(3)]); + update_delta.insert(1, vec![Value::Integer(1), Value::Integer(3)]); + + let output = filter.process_delta(update_delta); + + // The output delta should have both changes (both pass the filter b > 2) + assert_eq!(output.changes.len(), 2); + assert_eq!(output.changes[0].1, -1); // delete weight + assert_eq!(output.changes[1].1, 1); // insert weight + + // The current state should be consolidated to only have rows with positive weight + let final_state = filter.get_current_state(); + + // After consolidation, we should have only one row with rowid=1 + assert_eq!( + final_state.changes.len(), + 1, + "State should be consolidated to have only one row" + ); + assert_eq!(final_state.changes[0].0.rowid, 1); + assert_eq!( + final_state.changes[0].0.values, + vec![Value::Integer(1), Value::Integer(3)] + ); + assert_eq!(final_state.changes[0].1, 1); // positive weight + } } From 
05b275f865d2d345a055fc55a7215b77d4dd5013 Mon Sep 17 00:00:00 2001
From: Glauber Costa
Date: Tue, 26 Aug 2025 13:17:34 -0500
Subject: [PATCH 4/9] remove min/max and add more tests for other aggregations

min/max require O(N) storage because of deletions. It is easy to see
why: if you *add* a new row, you can quickly and incrementally check
whether it is smaller / larger than the previous accumulator. But when
you *delete* the row that currently holds the min / max, the
accumulator alone cannot tell you the runner-up: you have to look at
all the remaining values again, which means keeping them around.

Feldera uses something called "traces", which to me look a lot like
indexes. Once we implement materialization, maintaining a trace like
that will be easy. But to avoid having something broken, we'll just
disable min / max until then.
---
 core/incremental/operator.rs | 366 +++++++++++++++++++++++++++--------
 1 file changed, 281 insertions(+), 85 deletions(-)

diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs
index 34e2e2f1b..867d25352 100644
--- a/core/incremental/operator.rs
+++ b/core/incremental/operator.rs
@@ -364,8 +364,7 @@ pub enum AggregateFunction {
     Count,
     Sum(String),
     Avg(String),
-    Min(String),
-    Max(String),
+    // MIN and MAX are not supported - see comment in compiler.rs for explanation
 }
 
 impl Display for AggregateFunction {
@@ -374,8 +373,6 @@ impl Display for AggregateFunction {
             AggregateFunction::Count => write!(f, "COUNT(*)"),
             AggregateFunction::Sum(col) => write!(f, "SUM({col})"),
             AggregateFunction::Avg(col) => write!(f, "AVG({col})"),
-            AggregateFunction::Min(col) => write!(f, "MIN({col})"),
-            AggregateFunction::Max(col) => write!(f, "MAX({col})"),
         }
     }
 }
@@ -401,8 +398,8 @@ impl AggregateFunction {
             AggFunc::Count | AggFunc::Count0 => Some(AggregateFunction::Count),
             AggFunc::Sum => input_column.map(AggregateFunction::Sum),
             AggFunc::Avg => input_column.map(AggregateFunction::Avg),
-            AggFunc::Min => input_column.map(AggregateFunction::Min),
-            AggFunc::Max => input_column.map(AggregateFunction::Max),
+            // MIN and MAX are not supported in incremental views - see compiler.rs
+            AggFunc::Min | AggFunc::Max => None,
             _ => None, // Other aggregate functions not yet supported in DBSP
         }
     }
@@ -1281,10 +1278,8 @@ struct AggregateState {
     sums: HashMap,
     // For AVG: column_name -> (sum, count) for computing average
     avgs: HashMap,
-    // For MIN: column_name -> min value
-    mins: HashMap<String, Value>,
-    // For MAX: column_name -> max value
-    maxs: HashMap<String, Value>,
+    // MIN/MAX are not supported - they require O(n) storage overhead for handling deletions
+    // correctly. See comment in apply_delta() for details.
} impl AggregateState { @@ -1293,8 +1288,6 @@ impl AggregateState { count: 0, sums: HashMap::new(), avgs: HashMap::new(), - mins: HashMap::new(), - maxs: HashMap::new(), } } @@ -1343,43 +1336,6 @@ impl AggregateState { } } } - AggregateFunction::Min(col_name) => { - // MIN/MAX are more complex for incremental updates - // For now, we'll need to recompute from the full state - // This is a limitation we can improve later - if weight > 0 { - // Only update on insert - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - self.mins - .entry(col_name.clone()) - .and_modify(|existing| { - if val < existing { - *existing = val.clone(); - } - }) - .or_insert_with(|| val.clone()); - } - } - } - } - AggregateFunction::Max(col_name) => { - if weight > 0 { - // Only update on insert - if let Some(idx) = column_names.iter().position(|c| c == col_name) { - if let Some(val) = values.get(idx) { - self.maxs - .entry(col_name.clone()) - .and_modify(|existing| { - if val > existing { - *existing = val.clone(); - } - }) - .or_insert_with(|| val.clone()); - } - } - } - } } } } @@ -1413,12 +1369,6 @@ impl AggregateState { result.push(Value::Null); } } - AggregateFunction::Min(col_name) => { - result.push(self.mins.get(col_name).cloned().unwrap_or(Value::Null)); - } - AggregateFunction::Max(col_name) => { - result.push(self.maxs.get(col_name).cloned().unwrap_or(Value::Null)); - } } } @@ -2090,12 +2040,14 @@ mod tests { // Should only update one group (cat_0), not recount all groups assert_eq!(tracker.lock().unwrap().aggregation_updates, 1); - // Output should show cat_0 now has count 11 - assert_eq!(output.len(), 1); - assert!(!output.changes.is_empty()); - let (change_row, _weight) = &output.changes[0]; - assert_eq!(change_row.values[0], Value::Text(Text::new("cat_0"))); - assert_eq!(change_row.values[1], Value::Integer(11)); + // Check the final state - cat_0 should now have count 11 + let final_state = agg.get_current_state(); + let cat_0 = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text(Text::new("cat_0"))) + .unwrap(); + assert_eq!(cat_0.0.values[1], Value::Integer(11)); // Verify incremental behavior let t = tracker.lock().unwrap(); @@ -2174,13 +2126,15 @@ mod tests { // Should only update Widget group assert_eq!(tracker.lock().unwrap().aggregation_updates, 1); - assert_eq!(output.len(), 1); - // Widget should now be 300 (250 + 50) - assert!(!output.changes.is_empty()); - let (change, _weight) = &output.changes[0]; - assert_eq!(change.values[0], Value::Text(Text::new("Widget"))); - assert_eq!(change.values[1], Value::Integer(300)); + // Check final state - Widget should now be 300 (250 + 50) + let final_state = agg.get_current_state(); + let widget = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text(Text::new("Widget"))) + .unwrap(); + assert_eq!(widget.0.values[1], Value::Integer(300)); } #[test] @@ -2247,13 +2201,15 @@ mod tests { ); let output = agg.process_delta(delta); - // Should only update user 1 - assert_eq!(output.len(), 1); - assert!(!output.changes.is_empty()); - let (change, _weight) = &output.changes[0]; - assert_eq!(change.values[0], Value::Integer(1)); // user_id - assert_eq!(change.values[1], Value::Integer(3)); // count: 2 + 1 - assert_eq!(change.values[2], Value::Integer(350)); // sum: 300 + 50 + // Check final state - user 1 should have updated count and sum + let final_state = agg.get_current_state(); + let user1 = final_state + .changes + .iter() + .find(|(row, _)| 
row.values[0] == Value::Integer(1)) + .unwrap(); + assert_eq!(user1.0.values[1], Value::Integer(3)); // count: 2 + 1 + assert_eq!(user1.0.values[2], Value::Integer(350)); // sum: 300 + 50 } #[test] @@ -2329,11 +2285,14 @@ mod tests { ); let output = agg.process_delta(delta); - // Category A avg should now be (10 + 20 + 30) / 3 = 20 - assert!(!output.changes.is_empty()); - let (change, _weight) = &output.changes[0]; - assert_eq!(change.values[0], Value::Text(Text::new("A"))); - assert_eq!(change.values[1], Value::Float(20.0)); + // Check final state - Category A avg should now be (10 + 20 + 30) / 3 = 20 + let final_state = agg.get_current_state(); + let cat_a = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text(Text::new("A"))) + .unwrap(); + assert_eq!(cat_a.0.values[1], Value::Float(20.0)); } #[test] @@ -2392,12 +2351,249 @@ mod tests { let output = agg.process_delta(delta); - // Should update to count=1, sum=200 - assert!(!output.changes.is_empty()); - let (change_row, _weight) = &output.changes[0]; - assert_eq!(change_row.values[0], Value::Text(Text::new("A"))); - assert_eq!(change_row.values[1], Value::Integer(1)); // count: 2 - 1 - assert_eq!(change_row.values[2], Value::Integer(200)); // sum: 300 - 100 + // Check final state - should update to count=1, sum=200 + let final_state = agg.get_current_state(); + let cat_a = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text(Text::new("A"))) + .unwrap(); + assert_eq!(cat_a.0.values[1], Value::Integer(1)); // count: 2 - 1 + assert_eq!(cat_a.0.values[2], Value::Integer(200)); // sum: 300 - 100 + } + + #[test] + fn test_count_aggregation_with_deletions() { + let aggregates = vec![AggregateFunction::Count]; + let group_by = vec!["category".to_string()]; + let input_columns = vec!["category".to_string(), "value".to_string()]; + + let mut agg = AggregateOperator::new(group_by, aggregates.clone(), input_columns); + + // Initialize with data + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Text("A".into()), Value::Integer(10)]); + init_data.insert(2, vec![Value::Text("A".into()), Value::Integer(20)]); + init_data.insert(3, vec![Value::Text("B".into()), Value::Integer(30)]); + agg.initialize(init_data); + + // Check initial counts + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 2); + + // Find group A and B + let group_a = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + let group_b = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("B".into())) + .unwrap(); + + assert_eq!(group_a.0.values[1], Value::Integer(2)); // COUNT = 2 for A + assert_eq!(group_b.0.values[1], Value::Integer(1)); // COUNT = 1 for B + + // Delete one row from group A + let mut delete_delta = Delta::new(); + delete_delta.delete(1, vec![Value::Text("A".into()), Value::Integer(10)]); + + let output = agg.process_delta(delete_delta); + + // Should emit retraction for old count and insertion for new count + assert_eq!(output.changes.len(), 2); + + // Check final state + let final_state = agg.get_current_state(); + let group_a_final = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + assert_eq!(group_a_final.0.values[1], Value::Integer(1)); // COUNT = 1 for A after deletion + + // Delete all rows from group B + let mut delete_all_b = Delta::new(); + delete_all_b.delete(3, vec![Value::Text("B".into()), Value::Integer(30)]); + + let output_b = 
agg.process_delta(delete_all_b); + assert_eq!(output_b.changes.len(), 1); // Only retraction, no new row + assert_eq!(output_b.changes[0].1, -1); // Retraction + + // Final state should not have group B + let final_state2 = agg.get_current_state(); + assert_eq!(final_state2.changes.len(), 1); // Only group A remains + assert_eq!(final_state2.changes[0].0.values[0], Value::Text("A".into())); + } + + #[test] + fn test_sum_aggregation_with_deletions() { + let aggregates = vec![AggregateFunction::Sum("value".to_string())]; + let group_by = vec!["category".to_string()]; + let input_columns = vec!["category".to_string(), "value".to_string()]; + + let mut agg = AggregateOperator::new(group_by, aggregates.clone(), input_columns); + + // Initialize with data + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Text("A".into()), Value::Integer(10)]); + init_data.insert(2, vec![Value::Text("A".into()), Value::Integer(20)]); + init_data.insert(3, vec![Value::Text("B".into()), Value::Integer(30)]); + init_data.insert(4, vec![Value::Text("B".into()), Value::Integer(15)]); + agg.initialize(init_data); + + // Check initial sums + let state = agg.get_current_state(); + let group_a = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + let group_b = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("B".into())) + .unwrap(); + + assert_eq!(group_a.0.values[1], Value::Integer(30)); // SUM = 30 for A (10+20) + assert_eq!(group_b.0.values[1], Value::Integer(45)); // SUM = 45 for B (30+15) + + // Delete one row from group A + let mut delete_delta = Delta::new(); + delete_delta.delete(2, vec![Value::Text("A".into()), Value::Integer(20)]); + + agg.process_delta(delete_delta); + + // Check updated sum + let state = agg.get_current_state(); + let group_a = state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + assert_eq!(group_a.0.values[1], Value::Integer(10)); // SUM = 10 for A after deletion + + // Delete all from group B + let mut delete_all_b = Delta::new(); + delete_all_b.delete(3, vec![Value::Text("B".into()), Value::Integer(30)]); + delete_all_b.delete(4, vec![Value::Text("B".into()), Value::Integer(15)]); + + agg.process_delta(delete_all_b); + + // Group B should be gone + let final_state = agg.get_current_state(); + assert_eq!(final_state.changes.len(), 1); // Only group A remains + assert_eq!(final_state.changes[0].0.values[0], Value::Text("A".into())); + } + + #[test] + fn test_avg_aggregation_with_deletions() { + let aggregates = vec![AggregateFunction::Avg("value".to_string())]; + let group_by = vec!["category".to_string()]; + let input_columns = vec!["category".to_string(), "value".to_string()]; + + let mut agg = AggregateOperator::new(group_by, aggregates.clone(), input_columns); + + // Initialize with data + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Text("A".into()), Value::Integer(10)]); + init_data.insert(2, vec![Value::Text("A".into()), Value::Integer(20)]); + init_data.insert(3, vec![Value::Text("A".into()), Value::Integer(30)]); + agg.initialize(init_data); + + // Check initial average + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 1); + assert_eq!(state.changes[0].0.values[1], Value::Float(20.0)); // AVG = (10+20+30)/3 = 20 + + // Delete the middle value + let mut delete_delta = Delta::new(); + delete_delta.delete(2, vec![Value::Text("A".into()), Value::Integer(20)]); + + agg.process_delta(delete_delta); 
+
+        // Check updated average
+        let state = agg.get_current_state();
+        assert_eq!(state.changes[0].0.values[1], Value::Float(20.0)); // AVG = (10+30)/2 = 20 (same!)
+
+        // Delete another to change the average
+        let mut delete_another = Delta::new();
+        delete_another.delete(3, vec![Value::Text("A".into()), Value::Integer(30)]);
+
+        agg.process_delta(delete_another);
+
+        let state = agg.get_current_state();
+        assert_eq!(state.changes[0].0.values[1], Value::Float(10.0)); // AVG = 10/1 = 10
+    }
+
+    #[test]
+    fn test_multiple_aggregations_with_deletions() {
+        // Test COUNT, SUM, and AVG together
+        let aggregates = vec![
+            AggregateFunction::Count,
+            AggregateFunction::Sum("value".to_string()),
+            AggregateFunction::Avg("value".to_string()),
+        ];
+        let group_by = vec!["category".to_string()];
+        let input_columns = vec!["category".to_string(), "value".to_string()];
+
+        let mut agg = AggregateOperator::new(group_by, aggregates.clone(), input_columns);
+
+        // Initialize with data
+        let mut init_data = Delta::new();
+        init_data.insert(1, vec![Value::Text("A".into()), Value::Integer(100)]);
+        init_data.insert(2, vec![Value::Text("A".into()), Value::Integer(200)]);
+        init_data.insert(3, vec![Value::Text("B".into()), Value::Integer(50)]);
+        agg.initialize(init_data);
+
+        // Check initial state
+        let state = agg.get_current_state();
+        let group_a = state
+            .changes
+            .iter()
+            .find(|(row, _)| row.values[0] == Value::Text("A".into()))
+            .unwrap();
+
+        assert_eq!(group_a.0.values[1], Value::Integer(2)); // COUNT = 2
+        assert_eq!(group_a.0.values[2], Value::Integer(300)); // SUM = 300
+        assert_eq!(group_a.0.values[3], Value::Float(150.0)); // AVG = 150
+
+        // Delete one row from group A
+        let mut delete_delta = Delta::new();
+        delete_delta.delete(1, vec![Value::Text("A".into()), Value::Integer(100)]);
+
+        agg.process_delta(delete_delta);
+
+        // Check all aggregates updated correctly
+        let state = agg.get_current_state();
+        let group_a = state
+            .changes
+            .iter()
+            .find(|(row, _)| row.values[0] == Value::Text("A".into()))
+            .unwrap();
+
+        assert_eq!(group_a.0.values[1], Value::Integer(1)); // COUNT = 1
+        assert_eq!(group_a.0.values[2], Value::Integer(200)); // SUM = 200
+        assert_eq!(group_a.0.values[3], Value::Float(200.0)); // AVG = 200
+
+        // Insert a new row with floating point value
+        let mut insert_delta = Delta::new();
+        insert_delta.insert(4, vec![Value::Text("A".into()), Value::Float(50.5)]);
+
+        agg.process_delta(insert_delta);
+
+        let state = agg.get_current_state();
+        let group_a = state
+            .changes
+            .iter()
+            .find(|(row, _)| row.values[0] == Value::Text("A".into()))
+            .unwrap();
+
+        assert_eq!(group_a.0.values[1], Value::Integer(2)); // COUNT = 2
+        assert_eq!(group_a.0.values[2], Value::Float(250.5)); // SUM = 250.5
+        assert_eq!(group_a.0.values[3], Value::Float(125.25)); // AVG = 125.25
     }
 
     #[test]

From 7e4bacca5552e7990cf296797043eb0035d3fb70 Mon Sep 17 00:00:00 2001
From: Glauber Costa
Date: Tue, 26 Aug 2025 17:35:40 -0500
Subject: [PATCH 5/9] remove join operator

I am all but certain the join code paths are broken by now, since
nothing exercises the join operator yet. The code evolved a lot, and at
every turn there are issues with aggregators, projectors, filters...
some subtle, some not so subtle. We keep having to patch join slightly
as we make changes to the API, but we never truly verify that it keeps
working, because there is no support for joins in the views.

Therefore: let's remove it. We'll bring it back later.
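For the record, the operator being removed implemented the standard DBSP
rule for incremental joins:

    ∂(A ⋈ B) = A ⋈ ∂B + ∂A ⋈ B + ∂A ⋈ ∂B

So when joins come back, the skeleton is the same three merges the old
process_both_deltas performed. A minimal sketch (join() here is a
hypothetical stand-in for the per-pair join helpers being deleted):

    let mut result = Delta::new();
    result.merge(&join(&left_state, &right_delta)); // A ⋈ ∂B
    result.merge(&join(&left_delta, &right_state)); // ∂A ⋈ B
    result.merge(&join(&left_delta, &right_delta)); // ∂A ⋈ ∂B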
---
 core/incremental/operator.rs | 491 -----------------------------------
 1 file changed, 491 deletions(-)

diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs
index 867d25352..676a9a067 100644
--- a/core/incremental/operator.rs
+++ b/core/incremental/operator.rs
@@ -929,325 +929,6 @@ impl IncrementalOperator for ProjectOperator {
     }
 }

-/// Join operator - performs incremental joins using DBSP formula
-/// ∂(A ⋈ B) = A ⋈ ∂B + ∂A ⋈ B + ∂A ⋈ ∂B
-#[derive(Debug)]
-pub struct JoinOperator {
-    join_type: JoinType,
-    pub left_on_column: String,
-    pub right_on_column: String,
-    left_column_names: Vec<String>,
-    right_column_names: Vec<String>,
-    // Current accumulated state for both sides
-    left_state: Delta,
-    right_state: Delta,
-    // Index for efficient lookups: column_value_as_string -> vec of row_keys
-    // We use String representation of values since Value doesn't implement Hash
-    left_index: HashMap<String, Vec<i64>>,
-    right_index: HashMap<String, Vec<i64>>,
-    // Result state
-    current_state: Delta,
-    tracker: Option<Arc<Mutex<ComputationTracker>>>,
-    // For generating unique keys for join results
-    next_result_key: i64,
-}
-
-impl JoinOperator {
-    pub fn new(
-        join_type: JoinType,
-        left_on_column: String,
-        right_on_column: String,
-        left_column_names: Vec<String>,
-        right_column_names: Vec<String>,
-    ) -> Self {
-        Self {
-            join_type,
-            left_on_column,
-            right_on_column,
-            left_column_names,
-            right_column_names,
-            left_state: Delta::new(),
-            right_state: Delta::new(),
-            left_index: HashMap::new(),
-            right_index: HashMap::new(),
-            current_state: Delta::new(),
-            tracker: None,
-            next_result_key: 0,
-        }
-    }
-
-    pub fn set_tracker(&mut self, tracker: Arc<Mutex<ComputationTracker>>) {
-        self.tracker = Some(tracker);
-    }
-
-    /// Build index for a side of the join
-    fn build_index(
-        state: &Delta,
-        column_names: &[String],
-        on_column: &str,
-    ) -> HashMap<String, Vec<i64>> {
-        let mut index = HashMap::new();
-
-        // Find the column index
-        let col_idx = column_names.iter().position(|c| c == on_column);
-        if col_idx.is_none() {
-            return index;
-        }
-        let col_idx = col_idx.unwrap();
-
-        // Build the index
-        for (row, weight) in &state.changes {
-            // Include rows with positive weight in the index
-            if *weight > 0 {
-                if let Some(key_value) = row.values.get(col_idx) {
-                    // Convert value to string for indexing
-                    let key_str = format!("{key_value:?}");
-                    index
-                        .entry(key_str)
-                        .or_insert_with(Vec::new)
-                        .push(row.rowid);
-                }
-            }
-        }
-
-        index
-    }
-
-    /// Join two deltas
-    fn join_deltas(&self, left_delta: &Delta, right_delta: &Delta, next_key: &mut i64) -> Delta {
-        let mut result = Delta::new();
-
-        // Find column indices
-        let left_col_idx = self
-            .left_column_names
-            .iter()
-            .position(|c| c == &self.left_on_column)
-            .unwrap_or(0);
-        let right_col_idx = self
-            .right_column_names
-            .iter()
-            .position(|c| c == &self.right_on_column)
-            .unwrap_or(0);
-
-        // For each row in left_delta
-        for (left_row, left_weight) in &left_delta.changes {
-            // Process both inserts and deletes
-
-            let left_join_value = left_row.values.get(left_col_idx);
-            if left_join_value.is_none() {
-                continue;
-            }
-            let left_join_value = left_join_value.unwrap();
-
-            // Look up matching rows in right_delta
-            for (right_row, right_weight) in &right_delta.changes {
-                // Process both inserts and deletes
-
-                let right_join_value = right_row.values.get(right_col_idx);
-                if right_join_value.is_none() {
-                    continue;
-                }
-                let right_join_value = right_join_value.unwrap();
-
-                // Check if values match
-                if left_join_value == right_join_value {
-                    // Record the join lookup
-                    if let Some(tracker) = &self.tracker {
-
tracker.lock().unwrap().record_join_lookup(); - } - - // Create joined row - let mut joined_values = left_row.values.clone(); - joined_values.extend(right_row.values.clone()); - - // Generate a unique key for the result - let result_key = *next_key; - *next_key += 1; - - let joined_row = HashableRow::new(result_key, joined_values); - result - .changes - .push((joined_row, left_weight * right_weight)); - } - } - } - - result - } - - /// Join a delta with the full state using the index - fn join_delta_with_state( - &self, - delta: &Delta, - state: &Delta, - delta_on_left: bool, - next_key: &mut i64, - ) -> Delta { - let mut result = Delta::new(); - - let (delta_col_idx, state_col_names) = if delta_on_left { - ( - self.left_column_names - .iter() - .position(|c| c == &self.left_on_column) - .unwrap_or(0), - &self.right_column_names, - ) - } else { - ( - self.right_column_names - .iter() - .position(|c| c == &self.right_on_column) - .unwrap_or(0), - &self.left_column_names, - ) - }; - - // Use index for efficient lookup - let state_index = Self::build_index( - state, - state_col_names, - if delta_on_left { - &self.right_on_column - } else { - &self.left_on_column - }, - ); - - for (delta_row, delta_weight) in &delta.changes { - // Process both inserts and deletes - - let delta_join_value = delta_row.values.get(delta_col_idx); - if delta_join_value.is_none() { - continue; - } - let delta_join_value = delta_join_value.unwrap(); - - // Use index to find matching rows - let delta_key_str = format!("{delta_join_value:?}"); - if let Some(matching_keys) = state_index.get(&delta_key_str) { - for state_key in matching_keys { - // Look up in the state - find the row with this rowid - let state_row_opt = state - .changes - .iter() - .find(|(row, weight)| row.rowid == *state_key && *weight > 0); - - if let Some((state_row, state_weight)) = state_row_opt { - // Record the join lookup - if let Some(tracker) = &self.tracker { - tracker.lock().unwrap().record_join_lookup(); - } - - // Create joined row - let joined_values = if delta_on_left { - let mut v = delta_row.values.clone(); - v.extend(state_row.values.clone()); - v - } else { - let mut v = state_row.values.clone(); - v.extend(delta_row.values.clone()); - v - }; - - let result_key = *next_key; - *next_key += 1; - - let joined_row = HashableRow::new(result_key, joined_values); - result - .changes - .push((joined_row, delta_weight * state_weight)); - } - } - } - } - - result - } - - /// Initialize both sides of the join - pub fn initialize_both(&mut self, left_data: Delta, right_data: Delta) { - self.left_state = left_data.clone(); - self.right_state = right_data.clone(); - - // Build indices - self.left_index = Self::build_index( - &self.left_state, - &self.left_column_names, - &self.left_on_column, - ); - self.right_index = Self::build_index( - &self.right_state, - &self.right_column_names, - &self.right_on_column, - ); - - // Perform initial join - let mut next_key = self.next_result_key; - self.current_state = self.join_deltas(&self.left_state, &self.right_state, &mut next_key); - self.next_result_key = next_key; - } - - /// Process deltas for both sides using DBSP formula - /// ∂(A ⋈ B) = A ⋈ ∂B + ∂A ⋈ B + ∂A ⋈ ∂B - pub fn process_both_deltas(&mut self, left_delta: Delta, right_delta: Delta) -> Delta { - let mut result = Delta::new(); - let mut next_key = self.next_result_key; - - // A ⋈ ∂B (existing left with new right) - let a_join_db = - self.join_delta_with_state(&right_delta, &self.left_state, false, &mut next_key); - result.merge(&a_join_db); 
- - // ∂A ⋈ B (new left with existing right) - let da_join_b = - self.join_delta_with_state(&left_delta, &self.right_state, true, &mut next_key); - result.merge(&da_join_b); - - // ∂A ⋈ ∂B (new left with new right) - let da_join_db = self.join_deltas(&left_delta, &right_delta, &mut next_key); - result.merge(&da_join_db); - - // Update the next key counter - self.next_result_key = next_key; - - // Update state - self.left_state.merge(&left_delta); - self.right_state.merge(&right_delta); - self.current_state.merge(&result); - - // Rebuild indices if needed - self.left_index = Self::build_index( - &self.left_state, - &self.left_column_names, - &self.left_on_column, - ); - self.right_index = Self::build_index( - &self.right_state, - &self.right_column_names, - &self.right_on_column, - ); - - result - } - - pub fn get_current_state(&self) -> &Delta { - &self.current_state - } - - /// Process a delta from the left table only - pub fn process_left_delta(&mut self, left_delta: Delta) -> Delta { - let empty_delta = Delta::new(); - self.process_both_deltas(left_delta, empty_delta) - } - - /// Process a delta from the right table only - pub fn process_right_delta(&mut self, right_delta: Delta) -> Delta { - let empty_delta = Delta::new(); - self.process_both_deltas(empty_delta, right_delta) - } -} - /// Aggregate operator - performs incremental aggregation with GROUP BY /// Maintains running totals/counts that are updated incrementally #[derive(Debug, Clone)] @@ -1817,178 +1498,6 @@ mod tests { assert!(found_insertion, "Should have found insertion"); } - // Join tests - #[test] - fn test_join_uses_delta_formula() { - let tracker = Arc::new(Mutex::new(ComputationTracker::new())); - - // Create join operator - let mut join = JoinOperator::new( - JoinType::Inner, - "user_id".to_string(), - "user_id".to_string(), - vec!["user_id".to_string(), "email".to_string()], - vec![ - "login_id".to_string(), - "user_id".to_string(), - "timestamp".to_string(), - ], - ); - join.set_tracker(tracker.clone()); - - // Initial data: emails table - let mut emails = Delta::new(); - emails.insert( - 1, - vec![ - Value::Integer(1), - Value::Text(Text::new("alice@example.com")), - ], - ); - emails.insert( - 2, - vec![Value::Integer(2), Value::Text(Text::new("bob@example.com"))], - ); - - // Initial data: logins table - let mut logins = Delta::new(); - logins.insert( - 1, - vec![Value::Integer(1), Value::Integer(1), Value::Integer(1000)], - ); - logins.insert( - 2, - vec![Value::Integer(2), Value::Integer(1), Value::Integer(2000)], - ); - - // Initialize join - join.initialize_both(emails.clone(), logins.clone()); - - // Reset tracker for delta processing - tracker.lock().unwrap().join_lookups = 0; - - // Add one login for bob (user_id=2) - let mut delta_logins = Delta::new(); - delta_logins.insert( - 3, - vec![Value::Integer(3), Value::Integer(2), Value::Integer(3000)], - ); - - // Process delta - should use incremental formula - let empty_delta = Delta::new(); - let output = join.process_both_deltas(empty_delta, delta_logins); - - // Should have one join result (bob's new login) - assert_eq!(output.len(), 1); - - // Verify we used index lookups, not nested loops - // Should have done 1 lookup (finding bob's email for the new login) - let lookups = tracker.lock().unwrap().join_lookups; - assert_eq!(lookups, 1, "Should use index lookup, not scan all emails"); - - // Verify incremental behavior - we processed only the delta - let t = tracker.lock().unwrap(); - assert_incremental(&t, 1, 3); // 1 operation for 3 total rows - } - - 
#[test]
-    fn test_join_maintains_index() {
-        // Create join operator
-        let mut join = JoinOperator::new(
-            JoinType::Inner,
-            "id".to_string(),
-            "ref_id".to_string(),
-            vec!["id".to_string(), "name".to_string()],
-            vec!["ref_id".to_string(), "value".to_string()],
-        );
-
-        // Initial data
-        let mut left = Delta::new();
-        left.insert(1, vec![Value::Integer(1), Value::Text(Text::new("A"))]);
-        left.insert(2, vec![Value::Integer(2), Value::Text(Text::new("B"))]);
-
-        let mut right = Delta::new();
-        right.insert(1, vec![Value::Integer(1), Value::Integer(100)]);
-
-        // Initialize - should build index
-        join.initialize_both(left.clone(), right.clone());
-
-        // Verify initial join worked
-        let state = join.get_current_state();
-        assert_eq!(state.changes.len(), 1); // One match: id=1
-
-        // Add new item to left
-        let mut delta_left = Delta::new();
-        delta_left.insert(3, vec![Value::Integer(3), Value::Text(Text::new("C"))]);
-
-        // Add matching item to right
-        let mut delta_right = Delta::new();
-        delta_right.insert(2, vec![Value::Integer(3), Value::Integer(300)]);
-
-        // Process deltas
-        let output = join.process_both_deltas(delta_left, delta_right);
-
-        // Should have new join result
-        assert_eq!(output.len(), 1);
-
-        // Verify the join result has the expected values
-        assert!(!output.changes.is_empty());
-        let (result, _weight) = &output.changes[0];
-        assert_eq!(result.values.len(), 4); // id, name, ref_id, value
-    }
-
-    #[test]
-    fn test_join_formula_correctness() {
-        // Test the DBSP formula: ∂(A ⋈ B) = A ⋈ ∂B + ∂A ⋈ B + ∂A ⋈ ∂B
-        let tracker = Arc::new(Mutex::new(ComputationTracker::new()));
-
-        let mut join = JoinOperator::new(
-            JoinType::Inner,
-            "x".to_string(),
-            "x".to_string(),
-            vec!["x".to_string(), "a".to_string()],
-            vec!["x".to_string(), "b".to_string()],
-        );
-        join.set_tracker(tracker.clone());
-
-        // Initial state A
-        let mut a = Delta::new();
-        a.insert(1, vec![Value::Integer(1), Value::Text(Text::new("a1"))]);
-        a.insert(2, vec![Value::Integer(2), Value::Text(Text::new("a2"))]);
-
-        // Initial state B
-        let mut b = Delta::new();
-        b.insert(1, vec![Value::Integer(1), Value::Text(Text::new("b1"))]);
-        b.insert(2, vec![Value::Integer(2), Value::Text(Text::new("b2"))]);
-
-        join.initialize_both(a.clone(), b.clone());
-
-        // Reset tracker
-        tracker.lock().unwrap().join_lookups = 0;
-
-        // Delta for A (add x=3)
-        let mut delta_a = Delta::new();
-        delta_a.insert(3, vec![Value::Integer(3), Value::Text(Text::new("a3"))]);
-
-        // Delta for B (add x=3 and x=1)
-        let mut delta_b = Delta::new();
-        delta_b.insert(3, vec![Value::Integer(3), Value::Text(Text::new("b3"))]);
-        delta_b.insert(4, vec![Value::Integer(1), Value::Text(Text::new("b1_new"))]);
-
-        let output = join.process_both_deltas(delta_a, delta_b);
-
-        // Expected results:
-        // A ⋈ ∂B: (1,a1) ⋈ (1,b1_new) = 1 result
-        // ∂A ⋈ B: (3,a3) ⋈ nothing = 0 results
-        // ∂A ⋈ ∂B: (3,a3) ⋈ (3,b3) = 1 result
-        // Total: 2 results
-        assert_eq!(output.len(), 2);
-
-        // Verify we're doing incremental work
-        let lookups = tracker.lock().unwrap().join_lookups;
-        assert!(lookups <= 4, "Should use efficient index lookups");
-    }
-
     // Aggregation tests
     #[test]
     fn test_count_increments_not_recounts() {

From 898c0260f3cd92d4031c24626c449a8470c475ba Mon Sep 17 00:00:00 2001
From: Glauber Costa
Date: Tue, 26 Aug 2025 21:32:37 -0500
Subject: [PATCH 6/9] move operator to eval / commit pattern

We need a read-only phase and a commit phase. Otherwise we will never
be able to roll back changes properly. We currently do that, but we do
it in the view.
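Roughly, the split we are moving to looks like this (a sketch of the
intent; the exact signatures are in the diff below):

    /// Read-only: compute the output delta, folding in any uncommitted
    /// rows from the current transaction, without touching operator state.
    fn eval(&self, delta: Delta, uncommitted: Option<Delta>) -> Delta;

    /// Fold the delta into internal state once the transaction commits,
    /// returning what downstream operators should see.
    fn commit(&mut self, delta: Delta) -> Delta;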
Before we move to circuits, this needs to be internalized by the
operator.
---
 core/incremental/operator.rs | 592 ++++++++++++++++++++++++++++++++---
 core/incremental/view.rs     |  12 +-
 2 files changed, 562 insertions(+), 42 deletions(-)

diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs
index 676a9a067..75366ed06 100644
--- a/core/incremental/operator.rs
+++ b/core/incremental/operator.rs
@@ -414,8 +414,18 @@ pub trait IncrementalOperator: Debug {
     /// Initialize with base data
     fn initialize(&mut self, data: Delta);

-    /// Process a delta (incremental update)
-    fn process_delta(&mut self, delta: Delta) -> Delta;
+    /// Evaluate the operator with a delta, without modifying internal state
+    /// This is used during query execution to compute results including uncommitted changes
+    ///
+    /// # Arguments
+    /// * `delta` - The committed delta to process
+    /// * `uncommitted` - Optional uncommitted changes from the current transaction
+    fn eval(&self, delta: Delta, uncommitted: Option<Delta>) -> Delta;
+
+    /// Commit a delta to the operator's internal state and return the output
+    /// This is called when a transaction commits, making changes permanent
+    /// Returns the output delta (what downstream operators should see)
+    fn commit(&mut self, delta: Delta) -> Delta;

     /// Get current accumulated state
     fn get_current_state(&self) -> Delta;
@@ -551,20 +561,51 @@ impl IncrementalOperator for FilterOperator {
         }
     }

-    fn process_delta(&mut self, delta: Delta) -> Delta {
+    fn eval(&self, delta: Delta, uncommitted: Option<Delta>) -> Delta {
         let mut output_delta = Delta::new();

-        // Process only the delta, not the entire state
+        // Merge delta with uncommitted if present
+        let combined_delta = if let Some(uncommitted) = uncommitted {
+            let mut combined = delta;
+            combined.merge(&uncommitted);
+            combined
+        } else {
+            delta
+        };
+
+        // Process the combined delta through the filter
+        for (row, weight) in combined_delta.changes {
+            if let Some(tracker) = &self.tracker {
+                tracker.lock().unwrap().record_filter();
+            }
+
+            // Only pass through rows that satisfy the filter predicate
+            // For deletes (weight < 0), we only pass them if the row values
+            // would have passed the filter (meaning it was in the view)
+            if self.evaluate_predicate(&row.values) {
+                output_delta.changes.push((row, weight));
+            }
+        }
+
+        output_delta
+    }
+
+    fn commit(&mut self, delta: Delta) -> Delta {
+        let mut output_delta = Delta::new();
+
+        // Commit the delta to our internal state
+        // Only pass through and track rows that satisfy the filter predicate
         for (row, weight) in delta.changes {
             if let Some(tracker) = &self.tracker {
                 tracker.lock().unwrap().record_filter();
             }

+            // Only track and output rows that pass the filter
+            // For deletes, this means the row was in the view (its values pass the filter)
+            // For inserts, this means the row should be in the view
             if self.evaluate_predicate(&row.values) {
-                output_delta.changes.push((row.clone(), weight));
-
-                // Update our state
-                self.current_state.changes.push((row, weight));
+                self.current_state.changes.push((row.clone(), weight));
+                output_delta.changes.push((row, weight));
             }
         }

@@ -899,19 +940,45 @@ impl IncrementalOperator for ProjectOperator {
         }
     }

-    fn process_delta(&mut self, delta: Delta) -> Delta {
+    fn eval(&self, delta: Delta, uncommitted: Option<Delta>) -> Delta {
         let mut output_delta = Delta::new();

-        for (row, weight) in &delta.changes {
+        // Merge delta with uncommitted if present
+        let combined_delta = if let Some(uncommitted) = uncommitted {
+            let mut combined = delta;
+            combined.merge(&uncommitted);
+            combined
+        } else {
+            delta
+        };
+
+        for (row, weight) in &combined_delta.changes {
             if let Some(tracker) = &self.tracker {
                 tracker.lock().unwrap().record_project();
             }
             let projected = self.project_values(&row.values);
             let projected_row = HashableRow::new(row.rowid, projected);
+            output_delta.changes.push((projected_row, *weight));
+        }

-            output_delta.changes.push((projected_row.clone(), *weight));
-            self.current_state.changes.push((projected_row, *weight));
+        output_delta
+    }
+
+    fn commit(&mut self, delta: Delta) -> Delta {
+        let mut output_delta = Delta::new();
+
+        // Commit the delta to our internal state and build output
+        for (row, weight) in &delta.changes {
+            if let Some(tracker) = &self.tracker {
+                tracker.lock().unwrap().record_project();
+            }
+            let projected = self.project_values(&row.values);
+            let projected_row = HashableRow::new(row.rowid, projected);
+            self.current_state
+                .changes
+                .push((projected_row.clone(), *weight));
+            output_delta.changes.push((projected_row, *weight));
         }

         output_delta
@@ -1212,11 +1279,100 @@ impl AggregateOperator {

 impl IncrementalOperator for AggregateOperator {
     fn initialize(&mut self, data: Delta) {
-        // Process all initial data
-        self.process_delta(data);
+        // Process all initial data - this modifies state during initialization
+        let _ = self.process_delta(data);
     }

-    fn process_delta(&mut self, delta: Delta) -> Delta {
+    fn eval(&self, delta: Delta, uncommitted: Option<Delta>) -> Delta {
+        // Clone the current state to work with temporarily
+        let mut temp_group_states = self.group_states.clone();
+        let mut temp_group_key_values = self.group_key_values.clone();
+
+        // Merge delta with uncommitted if present
+        let combined_delta = if let Some(uncommitted) = uncommitted {
+            let mut combined = delta;
+            combined.merge(&uncommitted);
+            combined
+        } else {
+            delta
+        };
+
+        let mut output_delta = Delta::new();
+        let mut modified_groups = HashSet::new();
+        let mut old_values: HashMap<String, Vec<Value>> = HashMap::new();
+
+        // Process each change in the combined delta using temporary state
+        for (row, weight) in &combined_delta.changes {
+            if let Some(tracker) = &self.tracker {
+                tracker.lock().unwrap().record_aggregation();
+            }
+
+            // Extract group key
+            let group_key = self.extract_group_key(&row.values);
+            let group_key_str = Self::group_key_to_string(&group_key);
+
+            // Store old aggregate values BEFORE applying the delta
+            if !modified_groups.contains(&group_key_str) {
+                if let Some(state) = temp_group_states.get(&group_key_str) {
+                    let mut old_row = group_key.clone();
+                    old_row.extend(state.to_values(&self.aggregates));
+                    old_values.insert(group_key_str.clone(), old_row);
+                }
+            }
+
+            modified_groups.insert(group_key_str.clone());
+            temp_group_key_values.insert(group_key_str.clone(), group_key.clone());
+
+            // Get or create aggregate state for this group in temporary state
+            let state = temp_group_states
+                .entry(group_key_str.clone())
+                .or_insert_with(AggregateState::new);
+
+            // Apply the delta to the temporary aggregate state
+            state.apply_delta(
+                &row.values,
+                *weight,
+                &self.aggregates,
+                &self.input_column_names,
+            );
+        }
+
+        // Generate output delta for modified groups using temporary state
+        for group_key_str in modified_groups {
+            let group_key = temp_group_key_values
+                .get(&group_key_str)
+                .cloned()
+                .unwrap_or_default();
+
+            // Generate a unique key for this group
+            let result_key = group_key_str
+                .bytes()
+                .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64));
+
+            // Emit retraction for old value if it existed
+
if let Some(old_row_values) = old_values.get(&group_key_str) { + let old_row = HashableRow::new(result_key, old_row_values.clone()); + output_delta.changes.push((old_row, -1)); + } + + if let Some(state) = temp_group_states.get(&group_key_str) { + // Build output row: group_by columns + aggregate values + let mut output_values = group_key.clone(); + output_values.extend(state.to_values(&self.aggregates)); + + // Check if group should be included (count > 0) + if state.count > 0 { + let output_row = HashableRow::new(result_key, output_values); + output_delta.changes.push((output_row, 1)); + } + } + } + + output_delta + } + + fn commit(&mut self, delta: Delta) -> Delta { + // Actually update the internal state when committing and return the output self.process_delta(delta) } @@ -1322,7 +1478,8 @@ mod tests { ); // Process the incremental update - let output_delta = agg.process_delta(update_delta); + let output_delta = agg.eval(update_delta.clone(), None); + agg.commit(update_delta); // CRITICAL: The output delta should contain TWO changes: // 1. Retraction of old aggregate value (90) with weight -1 @@ -1454,7 +1611,8 @@ mod tests { ); // Process the incremental update - let output_delta = agg.process_delta(update_delta); + let output_delta = agg.eval(update_delta.clone(), None); + agg.commit(update_delta); // Should have 2 changes: retraction of old red team sum, insertion of new red team sum // Blue team should NOT be affected @@ -1544,10 +1702,12 @@ mod tests { ], ); - let output = agg.process_delta(delta); + let _output = agg.eval(delta.clone(), None); + agg.commit(delta); - // Should only update one group (cat_0), not recount all groups - assert_eq!(tracker.lock().unwrap().aggregation_updates, 1); + // Should update one group (cat_0) twice - once in eval, once in commit + // This is still incremental - we're not recounting all groups + assert_eq!(tracker.lock().unwrap().aggregation_updates, 2); // Check the final state - cat_0 should now have count 11 let final_state = agg.get_current_state(); @@ -1558,9 +1718,9 @@ mod tests { .unwrap(); assert_eq!(cat_0.0.values[1], Value::Integer(11)); - // Verify incremental behavior + // Verify incremental behavior - we process the delta twice (eval + commit) let t = tracker.lock().unwrap(); - assert_incremental(&t, 1, 101); + assert_incremental(&t, 2, 101); } #[test] @@ -1631,10 +1791,11 @@ mod tests { ], ); - let output = agg.process_delta(delta); + let _output = agg.eval(delta.clone(), None); + agg.commit(delta); - // Should only update Widget group - assert_eq!(tracker.lock().unwrap().aggregation_updates, 1); + // Should update Widget group twice (once in eval, once in commit) + assert_eq!(tracker.lock().unwrap().aggregation_updates, 2); // Check final state - Widget should now be 300 (250 + 50) let final_state = agg.get_current_state(); @@ -1708,7 +1869,8 @@ mod tests { 4, vec![Value::Integer(4), Value::Integer(1), Value::Integer(50)], ); - let output = agg.process_delta(delta); + let _output = agg.eval(delta.clone(), None); + agg.commit(delta); // Check final state - user 1 should have updated count and sum let final_state = agg.get_current_state(); @@ -1792,7 +1954,8 @@ mod tests { Value::Integer(30), ], ); - let output = agg.process_delta(delta); + let _output = agg.eval(delta.clone(), None); + agg.commit(delta); // Check final state - Category A avg should now be (10 + 20 + 30) / 3 = 20 let final_state = agg.get_current_state(); @@ -1858,7 +2021,8 @@ mod tests { ], ); - let output = agg.process_delta(delta); + let _output = 
agg.eval(delta.clone(), None); + agg.commit(delta); // Check final state - should update to count=1, sum=200 let final_state = agg.get_current_state(); @@ -1909,7 +2073,8 @@ mod tests { let mut delete_delta = Delta::new(); delete_delta.delete(1, vec![Value::Text("A".into()), Value::Integer(10)]); - let output = agg.process_delta(delete_delta); + let output = agg.eval(delete_delta.clone(), None); + agg.commit(delete_delta); // Should emit retraction for old count and insertion for new count assert_eq!(output.changes.len(), 2); @@ -1927,7 +2092,8 @@ mod tests { let mut delete_all_b = Delta::new(); delete_all_b.delete(3, vec![Value::Text("B".into()), Value::Integer(30)]); - let output_b = agg.process_delta(delete_all_b); + let output_b = agg.eval(delete_all_b.clone(), None); + agg.commit(delete_all_b); assert_eq!(output_b.changes.len(), 1); // Only retraction, no new row assert_eq!(output_b.changes[0].1, -1); // Retraction @@ -1973,7 +2139,8 @@ mod tests { let mut delete_delta = Delta::new(); delete_delta.delete(2, vec![Value::Text("A".into()), Value::Integer(20)]); - agg.process_delta(delete_delta); + let _ = agg.eval(delete_delta.clone(), None); + agg.commit(delete_delta); // Check updated sum let state = agg.get_current_state(); @@ -1989,7 +2156,8 @@ mod tests { delete_all_b.delete(3, vec![Value::Text("B".into()), Value::Integer(30)]); delete_all_b.delete(4, vec![Value::Text("B".into()), Value::Integer(15)]); - agg.process_delta(delete_all_b); + let _ = agg.eval(delete_all_b.clone(), None); + agg.commit(delete_all_b); // Group B should be gone let final_state = agg.get_current_state(); @@ -2021,7 +2189,8 @@ mod tests { let mut delete_delta = Delta::new(); delete_delta.delete(2, vec![Value::Text("A".into()), Value::Integer(20)]); - agg.process_delta(delete_delta); + let _ = agg.eval(delete_delta.clone(), None); + agg.commit(delete_delta); // Check updated average let state = agg.get_current_state(); @@ -2031,7 +2200,8 @@ mod tests { let mut delete_another = Delta::new(); delete_another.delete(3, vec![Value::Text("A".into()), Value::Integer(30)]); - agg.process_delta(delete_another); + let _ = agg.eval(delete_another.clone(), None); + agg.commit(delete_another); let state = agg.get_current_state(); assert_eq!(state.changes[0].0.values[1], Value::Float(10.0)); // AVG = 10/1 = 10 @@ -2073,7 +2243,8 @@ mod tests { let mut delete_delta = Delta::new(); delete_delta.delete(1, vec![Value::Text("A".into()), Value::Integer(100)]); - agg.process_delta(delete_delta); + let _ = agg.eval(delete_delta.clone(), None); + agg.commit(delete_delta); // Check all aggregates updated correctly let state = agg.get_current_state(); @@ -2091,7 +2262,8 @@ mod tests { let mut insert_delta = Delta::new(); insert_delta.insert(4, vec![Value::Text("A".into()), Value::Float(50.5)]); - agg.process_delta(insert_delta); + let _ = agg.eval(insert_delta.clone(), None); + agg.commit(insert_delta); let state = agg.get_current_state(); let group_a = state @@ -2138,7 +2310,8 @@ mod tests { update_delta.delete(3, vec![Value::Integer(3), Value::Integer(3)]); update_delta.insert(1, vec![Value::Integer(1), Value::Integer(3)]); - let output = filter.process_delta(update_delta); + let output = filter.eval(update_delta.clone(), None); + filter.commit(update_delta); // The output delta should have both changes (both pass the filter b > 2) assert_eq!(output.changes.len(), 2); @@ -2161,4 +2334,349 @@ mod tests { ); assert_eq!(final_state.changes[0].1, 1); // positive weight } + + // 
============================================================================ + // EVAL/COMMIT PATTERN TESTS + // These tests verify that the eval/commit pattern works correctly: + // - eval() computes results without modifying state + // - eval() with uncommitted data returns correct results + // - commit() updates internal state + // - State remains unchanged when eval() is called with uncommitted data + // ============================================================================ + + #[test] + fn test_filter_eval_with_uncommitted() { + let mut filter = FilterOperator::new( + FilterPredicate::GreaterThan { + column: "age".to_string(), + value: Value::Integer(25), + }, + vec!["id".to_string(), "name".to_string(), "age".to_string()], + ); + + // Initialize with some data + let mut init_data = Delta::new(); + init_data.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); + init_data.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(20), + ], + ); + filter.initialize(init_data); + + // Verify initial state (only Alice passes filter) + let state = filter.get_current_state(); + assert_eq!(state.changes.len(), 1); + assert_eq!(state.changes[0].0.rowid, 1); + + // Create uncommitted changes + let mut uncommitted = Delta::new(); + uncommitted.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(35), + ], + ); + uncommitted.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("David".into()), + Value::Integer(15), + ], + ); + + // Eval with uncommitted - should return filtered uncommitted rows + let result = filter.eval(Delta::new(), Some(uncommitted.clone())); + assert_eq!( + result.changes.len(), + 1, + "Only Charlie (35) should pass filter" + ); + assert_eq!(result.changes[0].0.rowid, 3); + + // Verify state hasn't changed + let state_after_eval = filter.get_current_state(); + assert_eq!( + state_after_eval.changes.len(), + 1, + "State should still only have Alice" + ); + assert_eq!(state_after_eval.changes[0].0.rowid, 1); + + // Now commit the changes + filter.commit(uncommitted); + + // State should now include Charlie (who passes filter) + let final_state = filter.get_current_state(); + assert_eq!( + final_state.changes.len(), + 2, + "State should now have Alice and Charlie" + ); + } + + #[test] + fn test_aggregate_eval_with_uncommitted_preserves_state() { + // This is the critical test - aggregations must not modify internal state during eval + let mut agg = AggregateOperator::new( + vec!["category".to_string()], + vec![ + AggregateFunction::Count, + AggregateFunction::Sum("amount".to_string()), + ], + vec![ + "id".to_string(), + "category".to_string(), + "amount".to_string(), + ], + ); + + // Initialize with data + let mut init_data = Delta::new(); + init_data.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("A".into()), + Value::Integer(100), + ], + ); + init_data.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("A".into()), + Value::Integer(200), + ], + ); + init_data.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("B".into()), + Value::Integer(150), + ], + ); + agg.initialize(init_data); + + // Check initial state: A -> (count=2, sum=300), B -> (count=1, sum=150) + let initial_state = agg.get_current_state(); + assert_eq!(initial_state.changes.len(), 2); + + // Store initial state for comparison + let initial_a = initial_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + 
assert_eq!(initial_a.0.values[1], Value::Integer(2)); // count + assert_eq!(initial_a.0.values[2], Value::Float(300.0)); // sum + + // Create uncommitted changes + let mut uncommitted = Delta::new(); + uncommitted.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("A".into()), + Value::Integer(50), + ], + ); + uncommitted.insert( + 5, + vec![ + Value::Integer(5), + Value::Text("C".into()), + Value::Integer(75), + ], + ); + + // Eval with uncommitted should return the delta (changes to aggregates) + let result = agg.eval(Delta::new(), Some(uncommitted.clone())); + + // Result should contain updates for A and new group C + // For A: retraction of old (2, 300) and insertion of new (3, 350) + // For C: insertion of (1, 75) + assert!(!result.changes.is_empty(), "Should have aggregate changes"); + + // CRITICAL: Verify internal state hasn't changed + let state_after_eval = agg.get_current_state(); + assert_eq!( + state_after_eval.changes.len(), + 2, + "State should still have only A and B" + ); + + let a_after_eval = state_after_eval + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + assert_eq!( + a_after_eval.0.values[1], + Value::Integer(2), + "A count should still be 2" + ); + assert_eq!( + a_after_eval.0.values[2], + Value::Float(300.0), + "A sum should still be 300" + ); + + // Now commit the changes + agg.commit(uncommitted); + + // State should now be updated + let final_state = agg.get_current_state(); + assert_eq!(final_state.changes.len(), 3, "Should now have A, B, and C"); + + let a_final = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("A".into())) + .unwrap(); + assert_eq!( + a_final.0.values[1], + Value::Integer(3), + "A count should now be 3" + ); + assert_eq!( + a_final.0.values[2], + Value::Float(350.0), + "A sum should now be 350" + ); + + let c_final = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("C".into())) + .unwrap(); + assert_eq!( + c_final.0.values[1], + Value::Integer(1), + "C count should be 1" + ); + assert_eq!( + c_final.0.values[2], + Value::Float(75.0), + "C sum should be 75" + ); + } + + #[test] + fn test_aggregate_eval_multiple_times_without_commit() { + // Test that calling eval multiple times with different uncommitted data + // doesn't pollute the internal state + let mut agg = AggregateOperator::new( + vec![], // No GROUP BY + vec![ + AggregateFunction::Count, + AggregateFunction::Sum("value".to_string()), + ], + vec!["id".to_string(), "value".to_string()], + ); + + // Initialize + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Integer(1), Value::Integer(100)]); + init_data.insert(2, vec![Value::Integer(2), Value::Integer(200)]); + agg.initialize(init_data); + + // Initial state: count=2, sum=300 + let initial_state = agg.get_current_state(); + assert_eq!(initial_state.changes.len(), 1); + assert_eq!(initial_state.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(initial_state.changes[0].0.values[1], Value::Float(300.0)); + + // First eval with uncommitted + let mut uncommitted1 = Delta::new(); + uncommitted1.insert(3, vec![Value::Integer(3), Value::Integer(50)]); + let _ = agg.eval(Delta::new(), Some(uncommitted1)); + + // State should be unchanged + let state1 = agg.get_current_state(); + assert_eq!(state1.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(state1.changes[0].0.values[1], Value::Float(300.0)); + + // Second eval with different uncommitted + let mut uncommitted2 = Delta::new(); + 
uncommitted2.insert(4, vec![Value::Integer(4), Value::Integer(75)]); + uncommitted2.insert(5, vec![Value::Integer(5), Value::Integer(25)]); + let _ = agg.eval(Delta::new(), Some(uncommitted2)); + + // State should STILL be unchanged + let state2 = agg.get_current_state(); + assert_eq!(state2.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(state2.changes[0].0.values[1], Value::Float(300.0)); + + // Third eval with deletion as uncommitted + let mut uncommitted3 = Delta::new(); + uncommitted3.delete(1, vec![Value::Integer(1), Value::Integer(100)]); + let _ = agg.eval(Delta::new(), Some(uncommitted3)); + + // State should STILL be unchanged + let state3 = agg.get_current_state(); + assert_eq!(state3.changes[0].0.values[0], Value::Integer(2)); + assert_eq!(state3.changes[0].0.values[1], Value::Float(300.0)); + } + + #[test] + fn test_aggregate_eval_with_mixed_committed_and_uncommitted() { + // Test eval with both committed delta and uncommitted changes + let mut agg = AggregateOperator::new( + vec!["type".to_string()], + vec![AggregateFunction::Count], + vec!["id".to_string(), "type".to_string()], + ); + + // Initialize + let mut init_data = Delta::new(); + init_data.insert(1, vec![Value::Integer(1), Value::Text("X".into())]); + init_data.insert(2, vec![Value::Integer(2), Value::Text("Y".into())]); + agg.initialize(init_data); + + // Create a committed delta (to be processed) + let mut committed_delta = Delta::new(); + committed_delta.insert(3, vec![Value::Integer(3), Value::Text("X".into())]); + + // Create uncommitted changes + let mut uncommitted = Delta::new(); + uncommitted.insert(4, vec![Value::Integer(4), Value::Text("Y".into())]); + uncommitted.insert(5, vec![Value::Integer(5), Value::Text("Z".into())]); + + // Eval with both - should process both but not commit + let result = agg.eval(committed_delta.clone(), Some(uncommitted)); + + // Result should reflect changes from both + assert!(!result.changes.is_empty()); + + // But internal state should be unchanged + let state = agg.get_current_state(); + assert_eq!(state.changes.len(), 2, "Should still have only X and Y"); + + // Now commit only the committed_delta + agg.commit(committed_delta); + + // State should now have X count=2, Y count=1 + let final_state = agg.get_current_state(); + let x = final_state + .changes + .iter() + .find(|(row, _)| row.values[0] == Value::Text("X".into())) + .unwrap(); + assert_eq!(x.0.values[1], Value::Integer(2)); + } } diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 4f4d4c6e6..7033fe83c 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -767,19 +767,21 @@ impl IncrementalView { } } - /// Apply filter operator to a delta if present + /// Apply filter operator to a delta if present and commit the changes fn apply_filter_to_delta(&mut self, delta: Delta) -> Delta { if let Some(ref mut filter_op) = self.filter_operator { - filter_op.process_delta(delta) + // Commit updates state and returns output + filter_op.commit(delta) } else { delta } } - /// Apply aggregation operator to a delta if this is an aggregated view + /// Apply aggregation operator to a delta if this is an aggregated view and commit the changes fn apply_aggregation_to_delta(&mut self, delta: Delta) -> Delta { if let Some(ref mut agg_op) = self.aggregate_operator { - agg_op.process_delta(delta) + // Commit updates state and returns output + agg_op.commit(delta) } else { delta } @@ -798,7 +800,7 @@ impl IncrementalView { // Apply projection operator if present (for non-aggregated views) if let 
Some(ref mut project_op) = self.project_operator {
-            current_delta = project_op.process_delta(current_delta);
+            current_delta = project_op.commit(current_delta);
         }

         current_delta = self.apply_aggregation_to_delta(current_delta);

From 29b93e3e58404c5f4d231c934d016cc2eb0844be Mon Sep 17 00:00:00 2001
From: Glauber Costa
Date: Mon, 25 Aug 2025 21:33:05 -0500
Subject: [PATCH 7/9] add DBSP circuit compiler

The next step is to adapt the view code to use circuits instead of
listing the operators manually.
---
 core/incremental/compiler.rs | 2922 ++++++++++++++++++++++++++++++++++
 core/incremental/mod.rs      |    1 +
 core/incremental/operator.rs |   42 +-
 core/translate/mod.rs        |    1 -
 4 files changed, 2964 insertions(+), 2 deletions(-)
 create mode 100644 core/incremental/compiler.rs

diff --git a/core/incremental/compiler.rs b/core/incremental/compiler.rs
new file mode 100644
index 000000000..9cd2e3702
--- /dev/null
+++ b/core/incremental/compiler.rs
@@ -0,0 +1,2922 @@
+//! DBSP Compiler: Converts Logical Plans to DBSP Circuits
+//!
+//! This module implements compilation from SQL logical plans to DBSP circuits.
+//! The initial version supports only the filter, projection and aggregate operators.
+//!
+//! Based on the DBSP paper: "DBSP: Automatic Incremental View Maintenance for Rich Query Languages"
+
+use crate::incremental::expr_compiler::CompiledExpression;
+use crate::incremental::operator::{
+    Delta, FilterOperator, FilterPredicate, IncrementalOperator, ProjectOperator,
+};
+// Note: logical module must be made pub(crate) in translate/mod.rs
+use crate::translate::logical::{BinaryOperator, LogicalExpr, LogicalPlan, SchemaRef};
+use crate::types::Value;
+use crate::{LimboError, Result};
+use std::collections::HashMap;
+use std::fmt::{self, Display, Formatter};
+
+/// A set of deltas for multiple tables/operators
+/// This provides a cleaner API for passing deltas through circuit execution
+#[derive(Debug, Clone, Default)]
+pub struct DeltaSet {
+    /// Deltas keyed by table/operator name
+    deltas: HashMap<String, Delta>,
+}
+
+impl DeltaSet {
+    /// Create a new empty delta set
+    pub fn new() -> Self {
+        Self {
+            deltas: HashMap::new(),
+        }
+    }
+
+    /// Create an empty delta set (more semantic for "no changes")
+    pub fn empty() -> Self {
+        Self {
+            deltas: HashMap::new(),
+        }
+    }
+
+    /// Add a delta for a table
+    pub fn insert(&mut self, table_name: String, delta: Delta) {
+        self.deltas.insert(table_name, delta);
+    }
+
+    /// Get delta for a table, returns empty delta if not found
+    pub fn get(&self, table_name: &str) -> Delta {
+        self.deltas
+            .get(table_name)
+            .cloned()
+            .unwrap_or_else(Delta::new)
+    }
+}
+
+/// Represents a DBSP operator in the compiled circuit
+#[derive(Debug, Clone, PartialEq)]
+pub enum DbspOperator {
+    /// Filter operator (σ) - filters records based on a predicate
+    Filter { predicate: DbspExpr },
+    /// Projection operator (π) - projects specific columns
+    Projection {
+        exprs: Vec<DbspExpr>,
+        schema: SchemaRef,
+    },
+    /// Aggregate operator (γ) - performs grouping and aggregation
+    Aggregate {
+        group_exprs: Vec<DbspExpr>,
+        aggr_exprs: Vec<crate::incremental::operator::AggregateFunction>,
+        schema: SchemaRef,
+    },
+    /// Input operator - source of data
+    Input { name: String, schema: SchemaRef },
+}
+
+/// Represents an expression in DBSP
+#[derive(Debug, Clone, PartialEq)]
+pub enum DbspExpr {
+    /// Column reference
+    Column(String),
+    /// Literal value
+    Literal(Value),
+    /// Binary expression
+    BinaryExpr {
+        left: Box<DbspExpr>,
+        op: BinaryOperator,
+        right: Box<DbspExpr>,
+    },
+}
+
+/// A node in the DBSP circuit DAG
+pub struct DbspNode {
+    /// Unique identifier for this node
+    pub id: usize,
+    /// The operator metadata
+    pub operator: DbspOperator,
+    /// Input nodes (edges in the DAG)
+    pub inputs: Vec<usize>,
+    /// The actual executable operator (if applicable)
+    pub executable: Option<Box<dyn IncrementalOperator>>,
+}
+
+impl std::fmt::Debug for DbspNode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("DbspNode")
+            .field("id", &self.id)
+            .field("operator", &self.operator)
+            .field("inputs", &self.inputs)
+            .field("has_executable", &self.executable.is_some())
+            .finish()
+    }
+}
+
+/// Represents a complete DBSP circuit (DAG of operators)
+#[derive(Debug)]
+pub struct DbspCircuit {
+    /// All nodes in the circuit, indexed by their ID
+    pub(super) nodes: HashMap<usize, DbspNode>,
+    /// Counter for generating unique node IDs
+    next_id: usize,
+    /// Root node ID (the final output)
+    pub(super) root: Option<usize>,
+}
+
+impl DbspCircuit {
+    /// Create a new empty circuit
+    pub fn new() -> Self {
+        Self {
+            nodes: HashMap::new(),
+            next_id: 0,
+            root: None,
+        }
+    }
+
+    /// Add a node to the circuit
+    fn add_node(
+        &mut self,
+        operator: DbspOperator,
+        inputs: Vec<usize>,
+        executable: Option<Box<dyn IncrementalOperator>>,
+    ) -> usize {
+        let id = self.next_id;
+        self.next_id += 1;
+
+        let node = DbspNode {
+            id,
+            operator,
+            inputs,
+            executable,
+        };
+
+        self.nodes.insert(id, node);
+        id
+    }
+
+    /// Initialize the circuit with base data. Should be called once before processing deltas.
+    /// If the database is restarting with materialized views, this can be skipped.
+    pub fn initialize(&mut self, input_data: HashMap<String, Delta>) -> Result<Delta> {
+        if let Some(root_id) = self.root {
+            self.initialize_node(root_id, &input_data)
+        } else {
+            Err(LimboError::ParseError(
+                "Circuit has no root node".to_string(),
+            ))
+        }
+    }
+
+    /// Initialize a specific node and its dependencies
+    fn initialize_node(
+        &mut self,
+        node_id: usize,
+        input_data: &HashMap<String, Delta>,
+    ) -> Result<Delta> {
+        // Clone to avoid borrow checker issues
+        let inputs = self
+            .nodes
+            .get(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?
+            .inputs
+            .clone();
+
+        // Initialize inputs first
+        let mut input_deltas = Vec::new();
+        for input_id in inputs {
+            let delta = self.initialize_node(input_id, input_data)?;
+            input_deltas.push(delta);
+        }
+
+        // Get mutable reference to node
+        let node = self
+            .nodes
+            .get_mut(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?;
+
+        // Initialize based on operator type
+        let result = match &node.operator {
+            DbspOperator::Input { name, .. } => {
+                // Get data from input map
+                input_data.get(name).cloned().unwrap_or_else(Delta::new)
+            }
+            DbspOperator::Filter { .. }
+            | DbspOperator::Projection { .. }
+            | DbspOperator::Aggregate { .. } => {
+                // Initialize the executable operator
+                if let Some(ref mut op) = node.executable {
+                    if !input_deltas.is_empty() {
+                        let input_delta = input_deltas[0].clone();
+                        op.initialize(input_delta);
+                        op.get_current_state()
+                    } else {
+                        Delta::new()
+                    }
+                } else {
+                    // If no executable, pass through the input
+                    if !input_deltas.is_empty() {
+                        input_deltas[0].clone()
+                    } else {
+                        Delta::new()
+                    }
+                }
+            }
+        };
+
+        Ok(result)
+    }
+
+    /// Execute the circuit with incremental input data (deltas).
+    /// Call initialize() first for initial data, then use execute() for updates.
+    ///
+    /// # Arguments
+    /// * `input_data` - The committed deltas to process
+    /// * `uncommitted_data` - Uncommitted transaction deltas that should be visible
+    ///                        during this execution but not stored in operators.
+    ///                        Use DeltaSet::empty() for no uncommitted changes.
+    pub fn execute(
+        &self,
+        input_data: HashMap<String, Delta>,
+        uncommitted_data: DeltaSet,
+    ) -> Result<Delta> {
+        if let Some(root_id) = self.root {
+            self.execute_node(root_id, &input_data, &uncommitted_data)
+        } else {
+            Err(LimboError::ParseError(
+                "Circuit has no root node".to_string(),
+            ))
+        }
+    }
+
+    /// Commit deltas to the circuit, updating internal operator state.
+    /// This should be called after execute() when you want to make changes permanent.
+    ///
+    /// # Arguments
+    /// * `input_data` - The deltas to commit (same as what was passed to execute)
+    pub fn commit(&mut self, input_data: HashMap<String, Delta>) -> Result<()> {
+        if let Some(root_id) = self.root {
+            self.commit_node(root_id, &input_data)?;
+        }
+        Ok(())
+    }
+
+    /// Commit a specific node in the circuit
+    fn commit_node(
+        &mut self,
+        node_id: usize,
+        input_data: &HashMap<String, Delta>,
+    ) -> Result<Delta> {
+        // Clone to avoid borrow checker issues
+        let inputs = self
+            .nodes
+            .get(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?
+            .inputs
+            .clone();
+
+        // Process inputs first
+        let mut input_deltas = Vec::new();
+        for input_id in inputs {
+            let delta = self.commit_node(input_id, input_data)?;
+            input_deltas.push(delta);
+        }
+
+        // Get mutable reference to node
+        let node = self
+            .nodes
+            .get_mut(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?;
+
+        // Commit based on operator type
+        let result = match &node.operator {
+            DbspOperator::Input { name, .. } => {
+                // For input nodes, just return the committed delta
+                input_data.get(name).cloned().unwrap_or_else(Delta::new)
+            }
+            DbspOperator::Filter { .. }
+            | DbspOperator::Projection { .. }
+            | DbspOperator::Aggregate { .. } => {
+                // Commit the delta to the executable operator
+                if let Some(ref mut op) = node.executable {
+                    if !input_deltas.is_empty() {
+                        let input_delta = input_deltas[0].clone();
+                        // Commit updates state and returns the output delta
+                        op.commit(input_delta)
+                    } else {
+                        Delta::new()
+                    }
+                } else {
+                    // If no executable, pass through the input
+                    if !input_deltas.is_empty() {
+                        input_deltas[0].clone()
+                    } else {
+                        Delta::new()
+                    }
+                }
+            }
+        };
+        Ok(result)
+    }
+
+    /// Execute a specific node in the circuit
+    fn execute_node(
+        &self,
+        node_id: usize,
+        input_data: &HashMap<String, Delta>,
+        uncommitted_data: &DeltaSet,
+    ) -> Result<Delta> {
+        // Clone to avoid borrow checker issues
+        let inputs = self
+            .nodes
+            .get(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?
+            .inputs
+            .clone();
+
+        // Process inputs first
+        let mut input_deltas = Vec::new();
+        for input_id in inputs {
+            let delta = self.execute_node(input_id, input_data, uncommitted_data)?;
+            input_deltas.push(delta);
+        }
+
+        // Get reference to node (read-only since we're using eval, not commit)
+        let node = self
+            .nodes
+            .get(&node_id)
+            .ok_or_else(|| LimboError::ParseError("Node not found".to_string()))?;
+
+        // Execute based on operator type
+        let result = match &node.operator {
+            DbspOperator::Input { name, .. } => {
+                // Get committed data from input map and merge with uncommitted if present
+                let committed = input_data.get(name).cloned().unwrap_or_else(Delta::new);
+                let uncommitted = uncommitted_data.get(name);
+
+                // If there's uncommitted data for this table, merge it with committed
+                if !uncommitted.is_empty() {
+                    let mut combined = committed;
+                    combined.merge(&uncommitted);
+                    combined
+                } else {
+                    committed
+                }
+            }
+            DbspOperator::Filter { .. }
+            | DbspOperator::Projection { .. }
+            | DbspOperator::Aggregate { .. } => {
+                // Process delta using the executable operator
+                if let Some(ref op) = node.executable {
+                    if !input_deltas.is_empty() {
+                        // Process the delta through the operator
+                        let input_delta = input_deltas[0].clone();
+
+                        // Use eval to compute result without modifying state
+                        // The uncommitted data has already been merged into input_delta if needed
+                        op.eval(input_delta, None)
+                    } else {
+                        Delta::new()
+                    }
+                } else {
+                    // If no executable, pass through the input
+                    if !input_deltas.is_empty() {
+                        input_deltas[0].clone()
+                    } else {
+                        Delta::new()
+                    }
+                }
+            }
+        };
+        Ok(result)
+    }
+}
+
+impl Display for DbspCircuit {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        writeln!(f, "DBSP Circuit:")?;
+        if let Some(root_id) = self.root {
+            self.fmt_node(f, root_id, 0)?;
+        }
+        Ok(())
+    }
+}
+
+impl DbspCircuit {
+    fn fmt_node(&self, f: &mut Formatter, node_id: usize, depth: usize) -> fmt::Result {
+        let indent = "  ".repeat(depth);
+        if let Some(node) = self.nodes.get(&node_id) {
+            match &node.operator {
+                DbspOperator::Filter { predicate } => {
+                    writeln!(f, "{indent}Filter[{node_id}]: {predicate:?}")?;
+                }
+                DbspOperator::Projection { exprs, .. } => {
+                    writeln!(f, "{indent}Projection[{node_id}]: {exprs:?}")?;
+                }
+                DbspOperator::Aggregate {
+                    group_exprs,
+                    aggr_exprs,
+                    ..
+                } => {
+                    writeln!(
+                        f,
+                        "{indent}Aggregate[{node_id}]: GROUP BY {group_exprs:?}, AGGR {aggr_exprs:?}"
+                    )?;
+                }
+                DbspOperator::Input { name, .. } => {
+                    writeln!(f, "{indent}Input[{node_id}]: {name}")?;
+                }
+            }
+
+            for input_id in &node.inputs {
+                self.fmt_node(f, *input_id, depth + 1)?;
+            }
+        }
+        Ok(())
+    }
+}
+
+/// Compiler from LogicalPlan to DBSP Circuit
+pub struct DbspCompiler {
+    circuit: DbspCircuit,
+}
+
+impl DbspCompiler {
+    /// Create a new DBSP compiler
+    pub fn new() -> Self {
+        Self {
+            circuit: DbspCircuit::new(),
+        }
+    }
+
+    /// Compile a logical plan to a DBSP circuit
+    pub fn compile(mut self, plan: &LogicalPlan) -> Result<DbspCircuit> {
+        let root_id = self.compile_plan(plan)?;
+        self.circuit.root = Some(root_id);
+        Ok(self.circuit)
+    }
+
+    /// Recursively compile a logical plan node
+    fn compile_plan(&mut self, plan: &LogicalPlan) -> Result<usize> {
+        match plan {
+            LogicalPlan::Projection(proj) => {
+                // Compile the input first
+                let input_id = self.compile_plan(&proj.input)?;
+
+                // Get input column names for the ProjectOperator
+                let input_schema = proj.input.schema();
+                let input_column_names: Vec<String> = input_schema.columns.iter()
+                    .map(|(name, _)| name.clone())
+                    .collect();
+
+                // Convert logical expressions to DBSP expressions
+                let dbsp_exprs = proj.exprs.iter()
+                    .map(Self::compile_expr)
+                    .collect::<Result<Vec<_>>>()?;
+
+                // Compile logical expressions to CompiledExpressions
+                let mut compiled_exprs = Vec::new();
+                let mut aliases = Vec::new();
+                for expr in &proj.exprs {
+                    let (compiled, alias) = Self::compile_expression(expr, &input_column_names)?;
+                    compiled_exprs.push(compiled);
+                    aliases.push(alias);
+                }
+
+                // Get output column names from the projection schema
+                let output_column_names: Vec<String> = proj.schema.columns.iter()
+                    .map(|(name, _)| name.clone())
+                    .collect();
+
+                // Create the ProjectOperator
+                let executable: Option<Box<dyn IncrementalOperator>> =
+                    ProjectOperator::from_compiled(compiled_exprs, aliases, input_column_names, output_column_names)
+                        .ok()
+                        .map(|op| Box::new(op) as Box<dyn IncrementalOperator>);
+
+                // Create projection node
+                let node_id = self.circuit.add_node(
+                    DbspOperator::Projection {
+                        exprs: dbsp_exprs,
+                        schema: proj.schema.clone(),
+                    },
+                    vec![input_id],
+                    executable,
+                );
+                Ok(node_id)
+            }
+            LogicalPlan::Filter(filter) => {
+                // Compile the input first
+                let input_id = self.compile_plan(&filter.input)?;
+
+                // Get column names from input schema
+                let input_schema = filter.input.schema();
+                let column_names: Vec<String> = input_schema.columns.iter()
+                    .map(|(name, _)| name.clone())
+                    .collect();
+
+                // Convert predicate to DBSP expression
+                let dbsp_predicate = Self::compile_expr(&filter.predicate)?;
+
+                // Convert to FilterPredicate
+                let filter_predicate = Self::compile_filter_predicate(&filter.predicate)?;
+
+                // Create executable operator
+                let executable: Box<dyn IncrementalOperator> =
+                    Box::new(FilterOperator::new(filter_predicate, column_names));
+
+                // Create filter node
+                let node_id = self.circuit.add_node(
+                    DbspOperator::Filter { predicate: dbsp_predicate },
+                    vec![input_id],
+                    Some(executable),
+                );
+                Ok(node_id)
+            }
+            LogicalPlan::Aggregate(agg) => {
+                // Compile the input first
+                let input_id = self.compile_plan(&agg.input)?;
+
+                // Get input column names
+                let input_schema = agg.input.schema();
+                let input_column_names: Vec<String> = input_schema.columns.iter()
+                    .map(|(name, _)| name.clone())
+                    .collect();
+
+                // Compile group by expressions to column names
+                let mut group_by_columns = Vec::new();
+                let mut dbsp_group_exprs = Vec::new();
+                for expr in &agg.group_expr {
+                    // For now, only support simple column references in GROUP BY
+                    if let LogicalExpr::Column(col) = expr {
+                        group_by_columns.push(col.name.clone());
+                        dbsp_group_exprs.push(DbspExpr::Column(col.name.clone()));
+                    } else {
+                        return Err(LimboError::ParseError(
+                            "Only column references are supported in GROUP BY for incremental views".to_string()
+                        ));
+                    }
+                }
+
+                // Compile aggregate expressions
+                let mut aggregate_functions = Vec::new();
+                for expr in &agg.aggr_expr {
+                    if let LogicalExpr::AggregateFunction { fun, args, .. } = expr {
+                        use crate::function::AggFunc;
+                        use crate::incremental::operator::AggregateFunction;
+
+                        let agg_fn = match fun {
+                            AggFunc::Count | AggFunc::Count0 => {
+                                AggregateFunction::Count
+                            }
+                            AggFunc::Sum => {
+                                if args.is_empty() {
+                                    return Err(LimboError::ParseError("SUM requires an argument".to_string()));
+                                }
+                                // Extract column name from the argument
+                                if let LogicalExpr::Column(col) = &args[0] {
+                                    AggregateFunction::Sum(col.name.clone())
+                                } else {
+                                    return Err(LimboError::ParseError(
+                                        "Only column references are supported in aggregate functions for incremental views".to_string()
+                                    ));
+                                }
+                            }
+                            AggFunc::Avg => {
+                                if args.is_empty() {
+                                    return Err(LimboError::ParseError("AVG requires an argument".to_string()));
+                                }
+                                if let LogicalExpr::Column(col) = &args[0] {
+                                    AggregateFunction::Avg(col.name.clone())
+                                } else {
+                                    return Err(LimboError::ParseError(
+                                        "Only column references are supported in aggregate functions for incremental views".to_string()
+                                    ));
+                                }
+                            }
+                            // MIN and MAX are not supported in incremental views due to storage overhead.
+                            // To correctly handle deletions, these operators would need to track all values
+                            // in each group, resulting in O(n) storage overhead. This is prohibitive for
+                            // large datasets. Alternative approaches like maintaining sorted indexes still
+                            // require O(n) storage. Until a more efficient solution is found, MIN/MAX
+                            // aggregations are not supported in materialized views.
+                            AggFunc::Min => {
+                                return Err(LimboError::ParseError(
+                                    "MIN aggregation is not supported in incremental materialized views due to O(n) storage overhead required for handling deletions".to_string()
+                                ));
+                            }
+                            AggFunc::Max => {
+                                return Err(LimboError::ParseError(
+                                    "MAX aggregation is not supported in incremental materialized views due to O(n) storage overhead required for handling deletions".to_string()
+                                ));
+                            }
+                            _ => {
+                                return Err(LimboError::ParseError(
+                                    format!("Unsupported aggregate function in DBSP compiler: {fun:?}")
+                                ));
+                            }
+                        };
+                        aggregate_functions.push(agg_fn);
+                    } else {
+                        return Err(LimboError::ParseError(
+                            "Expected aggregate function in aggregate expressions".to_string()
+                        ));
+                    }
+                }
+
+                // Create the AggregateOperator
+                use crate::incremental::operator::AggregateOperator;
+                let executable: Option<Box<dyn IncrementalOperator>> = Some(
+                    Box::new(AggregateOperator::new(
+                        group_by_columns,
+                        aggregate_functions.clone(),
+                        input_column_names,
+                    ))
+                );
+
+                // Create aggregate node
+                let node_id = self.circuit.add_node(
+                    DbspOperator::Aggregate {
+                        group_exprs: dbsp_group_exprs,
+                        aggr_exprs: aggregate_functions,
+                        schema: agg.schema.clone(),
+                    },
+                    vec![input_id],
+                    executable,
+                );
+                Ok(node_id)
+            }
+            LogicalPlan::TableScan(scan) => {
+                // Create input node (no executable needed for input)
+                let node_id = self.circuit.add_node(
+                    DbspOperator::Input {
+                        name: scan.table_name.clone(),
+                        schema: scan.schema.clone(),
+                    },
+                    vec![],
+                    None,
+                );
+                Ok(node_id)
+            }
+            _ => Err(LimboError::ParseError(
+                format!("Unsupported operator in DBSP compiler: only Filter, Projection and Aggregate are supported, got: {:?}",
+                    match plan {
+                        LogicalPlan::Sort(_) => "Sort",
+                        LogicalPlan::Limit(_) => "Limit",
+                        LogicalPlan::Union(_) => "Union",
+                        LogicalPlan::Distinct(_) => "Distinct",
+                        LogicalPlan::EmptyRelation(_) => "EmptyRelation",
+                        LogicalPlan::Values(_) => "Values",
+                        LogicalPlan::WithCTE(_) => "WithCTE",
+                        LogicalPlan::CTERef(_) => "CTERef",
+                        _ => "Unknown",
+                    }
+                )
+            )),
+        }
+    }
+
+    /// Convert a logical expression to a DBSP expression
+    fn compile_expr(expr: &LogicalExpr) -> Result<DbspExpr> {
+        match expr {
+            LogicalExpr::Column(col) => Ok(DbspExpr::Column(col.name.clone())),
+
+            LogicalExpr::Literal(val) => Ok(DbspExpr::Literal(val.clone())),
+
+            LogicalExpr::BinaryExpr { left, op, right } => {
+                let left_expr = Self::compile_expr(left)?;
+                let right_expr = Self::compile_expr(right)?;
+
+                Ok(DbspExpr::BinaryExpr {
+                    left: Box::new(left_expr),
+                    op: *op,
+                    right: Box::new(right_expr),
+                })
+            }
+
+            LogicalExpr::Alias { expr, .. } => {
+ // For aliases, compile the underlying expression
+ Self::compile_expr(expr)
+ }
+
+ // For complex expressions (functions, etc), we can't represent them as DbspExpr
+ // but that's OK - they'll be handled by the ProjectOperator's VDBE compilation
+ // For now, just use a placeholder
+ _ => {
+ // Use a literal null as placeholder - the actual execution will use the compiled VDBE
+ Ok(DbspExpr::Literal(Value::Null))
+ }
+ }
+ }
+
+ /// Compile a logical expression to a CompiledExpression and optional alias
+ fn compile_expression(
+ expr: &LogicalExpr,
+ input_column_names: &[String],
+ ) -> Result<(CompiledExpression, Option<String>)> {
+ // Check for alias first
+ if let LogicalExpr::Alias { expr, alias } = expr {
+ // For aliases, compile the underlying expression and return with alias
+ let (compiled, _) = Self::compile_expression(expr, input_column_names)?;
+ return Ok((compiled, Some(alias.clone())));
+ }
+
+ // Convert LogicalExpr to AST Expr
+ let ast_expr = Self::logical_to_ast_expr(expr)?;
+
+ // For all expressions (simple or complex), use CompiledExpression::compile
+ // This handles both trivial cases and complex VDBE compilation
+ // We need to set up the necessary context
+ use crate::{Database, MemoryIO, SymbolTable};
+ use std::sync::Arc;
+
+ // Create an internal connection for expression compilation
+ let io = Arc::new(MemoryIO::new());
+ let db = Database::open_file(io, ":memory:", false, false)?;
+ let internal_conn = db.connect()?;
+ internal_conn.query_only.set(true);
+ internal_conn.auto_commit.set(false);
+
+ // Create temporary symbol table
+ let temp_syms = SymbolTable::new();
+
+ // Get a minimal schema for compilation (we don't need the full schema for expressions)
+ let schema = crate::schema::Schema::new(false);
+
+ // Compile the expression using the existing CompiledExpression::compile
+ let compiled = CompiledExpression::compile(
+ &ast_expr,
+ input_column_names,
+ &schema,
+ &temp_syms,
+ internal_conn,
+ )?;
+
+ Ok((compiled, None))
+ }
+
+ /// Convert LogicalExpr to AST Expr
+ fn logical_to_ast_expr(expr: &LogicalExpr) -> Result<ast::Expr> {
+ use turso_parser::ast;
+
+ match expr {
+ LogicalExpr::Column(col) => Ok(ast::Expr::Id(ast::Name::Ident(col.name.clone()))),
+ LogicalExpr::Literal(val) => {
+ let lit = match val {
+ Value::Integer(i) => ast::Literal::Numeric(i.to_string()),
+ Value::Float(f) => ast::Literal::Numeric(f.to_string()),
+ Value::Text(t) => ast::Literal::String(t.to_string()),
+ Value::Blob(b) => ast::Literal::Blob(format!("{b:?}")),
+ Value::Null => ast::Literal::Null,
+ };
+ Ok(ast::Expr::Literal(lit))
+ }
+ LogicalExpr::BinaryExpr { left, op, right } => {
+ let left_expr = Self::logical_to_ast_expr(left)?;
+ let right_expr = Self::logical_to_ast_expr(right)?;
+ Ok(ast::Expr::Binary(
+ Box::new(left_expr),
+ *op,
+ Box::new(right_expr),
+ ))
+ }
+ LogicalExpr::ScalarFunction { fun, args } => {
+ let ast_args: Result<Vec<ast::Expr>> = args.iter().map(Self::logical_to_ast_expr).collect();
+ let ast_args: Vec<Box<ast::Expr>> = ast_args?.into_iter().map(Box::new).collect();
+ Ok(ast::Expr::FunctionCall {
+ name: ast::Name::Ident(fun.clone()),
+ distinctness: None,
+ args: ast_args,
+ order_by: Vec::new(),
+ filter_over: ast::FunctionTail {
+ filter_clause: None,
+ over_clause: None,
+ },
+ })
+ }
+ LogicalExpr::Alias { expr, .. } => {
+ // For conversion to AST, ignore the alias and convert the inner expression
+ Self::logical_to_ast_expr(expr)
+ }
+ LogicalExpr::AggregateFunction {
+ fun,
+ args,
+ distinct,
+ } => {
+ // Convert aggregate function to AST
+ let ast_args: Result<Vec<ast::Expr>> = args.iter().map(Self::logical_to_ast_expr).collect();
+ let ast_args: Vec<Box<ast::Expr>> = ast_args?.into_iter().map(Box::new).collect();
+
+ // Get the function name based on the aggregate type
+ let func_name = match fun {
+ crate::function::AggFunc::Count => "COUNT",
+ crate::function::AggFunc::Sum => "SUM",
+ crate::function::AggFunc::Avg => "AVG",
+ crate::function::AggFunc::Min => "MIN",
+ crate::function::AggFunc::Max => "MAX",
+ _ => {
+ return Err(LimboError::ParseError(format!(
+ "Unsupported aggregate function: {fun:?}"
+ )))
+ }
+ };
+
+ Ok(ast::Expr::FunctionCall {
+ name: ast::Name::Ident(func_name.to_string()),
+ distinctness: if *distinct {
+ Some(ast::Distinctness::Distinct)
+ } else {
+ None
+ },
+ args: ast_args,
+ order_by: Vec::new(),
+ filter_over: ast::FunctionTail {
+ filter_clause: None,
+ over_clause: None,
+ },
+ })
+ }
+ _ => Err(LimboError::ParseError(format!(
+ "Cannot convert LogicalExpr to AST Expr: {expr:?}"
+ ))),
+ }
+ }
+
+ /// Compile a logical expression to a FilterPredicate for execution
+ fn compile_filter_predicate(expr: &LogicalExpr) -> Result<FilterPredicate> {
+ match expr {
+ LogicalExpr::BinaryExpr { left, op, right } => {
+ // Extract column name and value for simple predicates
+ if let (LogicalExpr::Column(col), LogicalExpr::Literal(val)) =
+ (left.as_ref(), right.as_ref())
+ {
+ match op {
+ BinaryOperator::Equals => Ok(FilterPredicate::Equals {
+ column: col.name.clone(),
+ value: val.clone(),
+ }),
+ BinaryOperator::NotEquals => Ok(FilterPredicate::NotEquals {
+ column: col.name.clone(),
+ value: val.clone(),
+ }),
+ BinaryOperator::Greater => Ok(FilterPredicate::GreaterThan {
+ column: col.name.clone(),
+ value: val.clone(),
+ }),
+ BinaryOperator::GreaterEquals => Ok(FilterPredicate::GreaterThanOrEqual {
+ column: col.name.clone(),
+ value: val.clone(),
+ }),
+ BinaryOperator::Less => Ok(FilterPredicate::LessThan {
+ column: col.name.clone(),
+ value: val.clone(),
+ }),
+ BinaryOperator::LessEquals => Ok(FilterPredicate::LessThanOrEqual {
+ column: col.name.clone(),
+ value: val.clone(),
+ }),
+ BinaryOperator::And => {
+ // Handle AND of two predicates
+ let left_pred = Self::compile_filter_predicate(left)?;
+ let right_pred = Self::compile_filter_predicate(right)?;
+ Ok(FilterPredicate::And(
+ Box::new(left_pred),
+ Box::new(right_pred),
+ ))
+ }
+ BinaryOperator::Or => {
+ // Handle OR of two predicates
+ let left_pred = Self::compile_filter_predicate(left)?;
+ let right_pred = Self::compile_filter_predicate(right)?;
+ Ok(FilterPredicate::Or(
+ Box::new(left_pred),
+ Box::new(right_pred),
+ ))
+ }
+ _ => Err(LimboError::ParseError(format!(
+ "Unsupported operator in filter: {op:?}"
+ ))),
+ }
+ } else if matches!(op, BinaryOperator::And | BinaryOperator::Or) {
+ // Handle logical operators
+ let left_pred = Self::compile_filter_predicate(left)?;
+ let right_pred = Self::compile_filter_predicate(right)?;
+ match op {
+ BinaryOperator::And => Ok(FilterPredicate::And(
+ Box::new(left_pred),
+ Box::new(right_pred),
+ )),
+ BinaryOperator::Or => Ok(FilterPredicate::Or(
+ Box::new(left_pred),
+ Box::new(right_pred),
+ )),
+ _ => unreachable!(),
+ }
+ } else {
+ Err(LimboError::ParseError(
+ "Filter predicate must be column op value".to_string(),
+ ))
+ }
+ }
+ _ => Err(LimboError::ParseError(format!(
+ "Unsupported filter expression: {expr:?}"
filter expression: {expr:?}" + ))), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::incremental::operator::{Delta, FilterOperator, FilterPredicate}; + use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type}; + use crate::translate::logical::LogicalPlanBuilder; + use crate::translate::logical::LogicalSchema; + use std::sync::Arc; + use turso_parser::ast; + use turso_parser::parser::Parser; + + // Macro to create a test schema with a users table + macro_rules! test_schema { + () => {{ + let mut schema = Schema::new(false); + let users_table = BTreeTable { + name: "users".to_string(), + root_page: 2, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("age".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + schema.add_btree_table(Arc::new(users_table)); + schema + }}; + } + + // Macro to compile SQL to DBSP circuit + macro_rules! compile_sql { + ($sql:expr) => {{ + let schema = test_schema!(); + let mut parser = Parser::new($sql.as_bytes()); + let cmd = parser + .next() + .unwrap() // This returns Option> + .unwrap(); // This unwraps the Result + + match cmd { + ast::Cmd::Stmt(stmt) => { + let mut builder = LogicalPlanBuilder::new(&schema); + let logical_plan = builder.build_statement(&stmt).unwrap(); + DbspCompiler::new().compile(&logical_plan).unwrap() + } + _ => panic!("Only SQL statements are supported"), + } + }}; + } + + // Macro to assert circuit structure + macro_rules! assert_circuit { + ($circuit:expr, depth: $depth:expr, root: $root_type:ident) => { + assert_eq!($circuit.nodes.len(), $depth); + let node = get_node_at_level(&$circuit, 0); + assert!(matches!(node.operator, DbspOperator::$root_type { .. })); + }; + } + + // Macro to assert operator properties + macro_rules! assert_operator { + ($circuit:expr, $level:expr, Input { name: $name:expr }) => {{ + let node = get_node_at_level(&$circuit, $level); + match &node.operator { + DbspOperator::Input { name, .. } => assert_eq!(name, $name), + _ => panic!("Expected Input operator at level {}", $level), + } + }}; + ($circuit:expr, $level:expr, Filter) => {{ + let node = get_node_at_level(&$circuit, $level); + assert!(matches!(node.operator, DbspOperator::Filter { .. })); + }}; + ($circuit:expr, $level:expr, Projection { columns: [$($col:expr),*] }) => {{ + let node = get_node_at_level(&$circuit, $level); + match &node.operator { + DbspOperator::Projection { exprs, .. 
+ let expected_cols = vec![$($col),*];
+ let actual_cols: Vec<String> = exprs.iter().map(|e| {
+ match e {
+ DbspExpr::Column(name) => name.clone(),
+ _ => "expr".to_string(),
+ }
+ }).collect();
+ assert_eq!(actual_cols, expected_cols);
+ }
+ _ => panic!("Expected Projection operator at level {}", $level),
+ }
+ }};
+ }
+
+ // Macro to assert filter predicate
+ macro_rules! assert_filter_predicate {
+ ($circuit:expr, $level:expr, $col:literal > $val:literal) => {{
+ let node = get_node_at_level(&$circuit, $level);
+ match &node.operator {
+ DbspOperator::Filter { predicate } => match predicate {
+ DbspExpr::BinaryExpr { left, op, right } => {
+ assert!(matches!(op, ast::Operator::Greater));
+ assert!(matches!(&**left, DbspExpr::Column(name) if name == $col));
+ assert!(matches!(&**right, DbspExpr::Literal(Value::Integer($val))));
+ }
+ _ => panic!("Expected binary expression in filter"),
+ },
+ _ => panic!("Expected Filter operator at level {}", $level),
+ }
+ }};
+ ($circuit:expr, $level:expr, $col:literal < $val:literal) => {{
+ let node = get_node_at_level(&$circuit, $level);
+ match &node.operator {
+ DbspOperator::Filter { predicate } => match predicate {
+ DbspExpr::BinaryExpr { left, op, right } => {
+ assert!(matches!(op, ast::Operator::Less));
+ assert!(matches!(&**left, DbspExpr::Column(name) if name == $col));
+ assert!(matches!(&**right, DbspExpr::Literal(Value::Integer($val))));
+ }
+ _ => panic!("Expected binary expression in filter"),
+ },
+ _ => panic!("Expected Filter operator at level {}", $level),
+ }
+ }};
+ ($circuit:expr, $level:expr, $col:literal = $val:literal) => {{
+ let node = get_node_at_level(&$circuit, $level);
+ match &node.operator {
+ DbspOperator::Filter { predicate } => match predicate {
+ DbspExpr::BinaryExpr { left, op, right } => {
+ assert!(matches!(op, ast::Operator::Equals));
+ assert!(matches!(&**left, DbspExpr::Column(name) if name == $col));
+ assert!(matches!(&**right, DbspExpr::Literal(Value::Integer($val))));
+ }
+ _ => panic!("Expected binary expression in filter"),
+ },
+ _ => panic!("Expected Filter operator at level {}", $level),
+ }
+ }};
+ }
+
+ // Helper to get node at specific level from root
+ fn get_node_at_level(circuit: &DbspCircuit, level: usize) -> &DbspNode {
+ let mut current_id = circuit.root.expect("Circuit has no root");
+ for _ in 0..level {
+ let node = circuit.nodes.get(&current_id).expect("Node not found");
+ if node.inputs.is_empty() {
+ panic!("No more levels available, requested level {level}");
+ }
+ current_id = node.inputs[0];
+ }
+ circuit.nodes.get(&current_id).expect("Node not found")
+ }
+
+ // Helper to get the current accumulated state of the circuit (from the root operator)
+ // This returns the internal state including bookkeeping entries
+ fn get_current_state(circuit: &DbspCircuit) -> Result<Delta> {
+ if let Some(root_id) = circuit.root {
+ let node = circuit
+ .nodes
+ .get(&root_id)
+ .ok_or_else(|| LimboError::ParseError("Root node not found".to_string()))?;
+
+ if let Some(ref executable) = node.executable {
+ Ok(executable.get_current_state())
+ } else {
+ // Input nodes don't have executables but also don't have state
+ Ok(Delta::new())
+ }
+ } else {
+ Err(LimboError::ParseError(
+ "Circuit has no root node".to_string(),
+ ))
+ }
+ }
+
+ // Helper to create a DeltaSet from a HashMap (for tests)
+ fn delta_set_from_map(map: HashMap<String, Delta>) -> DeltaSet {
+ let mut delta_set = DeltaSet::new();
+ for (key, value) in map {
+ delta_set.insert(key, value);
+ }
+ delta_set
+ }
+
+ #[test]
+ fn test_simple_projection() {
+ let circuit = compile_sql!("SELECT name FROM users");
+
+ // Circuit has 2 nodes with Projection at root
+ assert_circuit!(circuit, depth: 2, root: Projection);
+
+ // Verify operators at each level
+ assert_operator!(circuit, 0, Projection { columns: ["name"] });
+ assert_operator!(circuit, 1, Input { name: "users" });
+ }
+
+ #[test]
+ fn test_filter_with_projection() {
+ let circuit = compile_sql!("SELECT name FROM users WHERE age > 18");
+
+ // Circuit has 3 nodes with Projection at root
+ assert_circuit!(circuit, depth: 3, root: Projection);
+
+ // Verify operators at each level
+ assert_operator!(circuit, 0, Projection { columns: ["name"] });
+ assert_operator!(circuit, 1, Filter);
+ assert_filter_predicate!(circuit, 1, "age" > 18);
+ assert_operator!(circuit, 2, Input { name: "users" });
+ }
+
+ #[test]
+ fn test_select_star() {
+ let mut circuit = compile_sql!("SELECT * FROM users");
+
+ // Create test data
+ let mut input_delta = Delta::new();
+ input_delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ input_delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(17),
+ ],
+ );
+
+ // Create input map
+ let mut inputs = HashMap::new();
+ inputs.insert("users".to_string(), input_delta);
+
+ // Initialize circuit with initial data
+ let result = circuit.initialize(inputs).unwrap();
+
+ // Should have all rows with all columns
+ assert_eq!(result.changes.len(), 2);
+
+ // Verify both rows are present with all columns
+ for (row, weight) in &result.changes {
+ assert_eq!(*weight, 1);
+ assert_eq!(row.values.len(), 3); // id, name, age
+ }
+ }
+
+ #[test]
+ fn test_execute_filter() {
+ let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+ // Create test data
+ let mut input_delta = Delta::new();
+ input_delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ input_delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(17),
+ ],
+ );
+ input_delta.insert(
+ 3,
+ vec![
+ Value::Integer(3),
+ Value::Text("Charlie".into()),
+ Value::Integer(30),
+ ],
+ );
+
+ // Create input map
+ let mut inputs = HashMap::new();
+ inputs.insert("users".to_string(), input_delta);
+
+ // Initialize circuit with initial data
+ let result = circuit.initialize(inputs).unwrap();
+
+ // Should only have Alice and Charlie (age > 18)
+ assert_eq!(
+ result.changes.len(),
+ 2,
+ "Expected 2 rows after filtering, got {}",
+ result.changes.len()
+ );
+
+ // Check that the filtered rows are correct
+ let names: Vec<String> = result
+ .changes
+ .iter()
+ .filter_map(|(row, weight)| {
+ if *weight > 0 && row.values.len() > 1 {
+ if let Value::Text(name) = &row.values[1] {
+ Some(name.to_string())
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ })
+ .collect();
+
+ assert!(
+ names.contains(&"Alice".to_string()),
+ "Alice should be in results"
+ );
+ assert!(
+ names.contains(&"Charlie".to_string()),
+ "Charlie should be in results"
+ );
+ assert!(
+ !names.contains(&"Bob".to_string()),
+ "Bob should not be in results"
+ );
+ }
+
+ #[test]
+ fn test_simple_column_projection() {
+ let mut circuit = compile_sql!("SELECT name, age FROM users");
+
+ // Create test data
+ let mut input_delta = Delta::new();
+ input_delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ input_delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+
Value::Integer(17), + ], + ); + + // Create input map + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), input_delta); + + // Initialize circuit with initial data + let result = circuit.initialize(inputs).unwrap(); + + // Should have all rows but only 2 columns (name, age) + assert_eq!(result.changes.len(), 2); + + for (row, _) in &result.changes { + assert_eq!(row.values.len(), 2); // Only name and age + // First value should be name (Text) + assert!(matches!(&row.values[0], Value::Text(_))); + // Second value should be age (Integer) + assert!(matches!(&row.values[1], Value::Integer(_))); + } + } + + #[test] + fn test_simple_aggregation() { + // Test COUNT(*) with GROUP BY + let mut circuit = compile_sql!("SELECT age, COUNT(*) FROM users GROUP BY age"); + + // Create test data + let mut input_delta = Delta::new(); + input_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + input_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(25), + ], + ); + input_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(30), + ], + ); + + // Create input map + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), input_delta); + + // Initialize circuit with initial data + let result = circuit.initialize(inputs).unwrap(); + + // Should have 2 groups: age 25 with count 2, age 30 with count 1 + assert_eq!(result.changes.len(), 2); + + // Check the results + let mut found_25 = false; + let mut found_30 = false; + + for (row, weight) in &result.changes { + assert_eq!(*weight, 1); + assert_eq!(row.values.len(), 2); // age, count + + if let (Value::Integer(age), Value::Integer(count)) = (&row.values[0], &row.values[1]) { + if *age == 25 { + assert_eq!(*count, 2, "Age 25 should have count 2"); + found_25 = true; + } else if *age == 30 { + assert_eq!(*count, 1, "Age 30 should have count 1"); + found_30 = true; + } + } + } + + assert!(found_25, "Should have group for age 25"); + assert!(found_30, "Should have group for age 30"); + } + + #[test] + fn test_sum_aggregation() { + // Test SUM with GROUP BY + let mut circuit = compile_sql!("SELECT name, SUM(age) FROM users GROUP BY name"); + + // Create test data - some names appear multiple times + let mut input_delta = Delta::new(); + input_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + input_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Alice".into()), + Value::Integer(30), + ], + ); + input_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Bob".into()), + Value::Integer(20), + ], + ); + + // Create input map + let mut inputs = HashMap::new(); + inputs.insert("users".to_string(), input_delta); + + // Initialize circuit with initial data + let result = circuit.initialize(inputs).unwrap(); + + // Should have 2 groups: Alice with sum 55, Bob with sum 20 + assert_eq!(result.changes.len(), 2); + + for (row, weight) in &result.changes { + assert_eq!(*weight, 1); + assert_eq!(row.values.len(), 2); // name, sum + + if let (Value::Text(name), Value::Float(sum)) = (&row.values[0], &row.values[1]) { + if name.as_str() == "Alice" { + assert_eq!(*sum, 55.0, "Alice should have sum 55"); + } else if name.as_str() == "Bob" { + assert_eq!(*sum, 20.0, "Bob should have sum 20"); + } + } + } + } + + #[test] + fn test_aggregation_without_group_by() { + // Test aggregation without GROUP BY - should produce a 
single row
+ let mut circuit = compile_sql!("SELECT COUNT(*), SUM(age), AVG(age) FROM users");
+
+ // Create test data
+ let mut input_delta = Delta::new();
+ input_delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ input_delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(30),
+ ],
+ );
+ input_delta.insert(
+ 3,
+ vec![
+ Value::Integer(3),
+ Value::Text("Charlie".into()),
+ Value::Integer(20),
+ ],
+ );
+
+ // Create input map
+ let mut inputs = HashMap::new();
+ inputs.insert("users".to_string(), input_delta);
+
+ // Initialize circuit with initial data
+ let result = circuit.initialize(inputs).unwrap();
+
+ // Should have exactly 1 row with all aggregates
+ assert_eq!(
+ result.changes.len(),
+ 1,
+ "Should have exactly one result row"
+ );
+
+ let (row, weight) = result.changes.first().unwrap();
+ assert_eq!(*weight, 1);
+ assert_eq!(row.values.len(), 3); // count, sum, avg
+
+ // Check aggregate results
+ // COUNT should be Integer
+ if let Value::Integer(count) = &row.values[0] {
+ assert_eq!(*count, 3, "COUNT(*) should be 3");
+ } else {
+ panic!("COUNT should be Integer, got {:?}", row.values[0]);
+ }
+
+ // SUM can be Integer (if whole number) or Float
+ match &row.values[1] {
+ Value::Integer(sum) => assert_eq!(*sum, 75, "SUM(age) should be 75"),
+ Value::Float(sum) => assert_eq!(*sum, 75.0, "SUM(age) should be 75.0"),
+ other => panic!("SUM should be Integer or Float, got {other:?}"),
+ }
+
+ // AVG should be Float
+ if let Value::Float(avg) = &row.values[2] {
+ assert_eq!(*avg, 25.0, "AVG(age) should be 25.0");
+ } else {
+ panic!("AVG should be Float, got {:?}", row.values[2]);
+ }
+ }
+
+ #[test]
+ fn test_expression_projection_execution() {
+ // Test that complex expressions work through VDBE compilation
+ let mut circuit = compile_sql!("SELECT hex(id) FROM users");
+
+ // Create test data
+ let mut input_delta = Delta::new();
+ input_delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ input_delta.insert(
+ 2,
+ vec![
+ Value::Integer(255),
+ Value::Text("Bob".into()),
+ Value::Integer(17),
+ ],
+ );
+
+ // Create input map
+ let mut inputs = HashMap::new();
+ inputs.insert("users".to_string(), input_delta);
+
+ // Initialize circuit with initial data
+ let result = circuit.initialize(inputs).unwrap();
+
+ assert_eq!(result.changes.len(), 2);
+
+ let hex_values: HashMap<i64, String> = result
+ .changes
+ .iter()
+ .map(|(row, _)| {
+ let rowid = row.rowid;
+ if let Value::Text(text) = &row.values[0] {
+ (rowid, text.to_string())
+ } else {
+ panic!("Expected Text value for hex() result");
+ }
+ })
+ .collect();
+
+ assert_eq!(
+ hex_values.get(&1).unwrap(),
+ "31",
+ "hex(1) should return '31' (hex of ASCII '1')"
+ );
+
+ assert_eq!(
+ hex_values.get(&2).unwrap(),
+ "323535",
+ "hex(255) should return '323535' (hex of ASCII '2', '5', '5')"
+ );
+ }
+
+ // TODO: This test currently fails on incremental updates.
+ // The initial execution works correctly, but incremental updates produce
+ // incorrect results (3 changes instead of 2, with wrong values).
+ // This tests that the aggregate operator correctly handles incremental
+ // updates when it's sandwiched between projection operators.
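+ // For reference, a correct incremental update of a scalar aggregate emits
+ // exactly two changes: a retraction of the old aggregate row (weight -1) and
+ // an insertion of the new one (weight +1), which is what the assertions at
+ // the end of this test expect.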
+ #[test] + fn test_projection_aggregation_projection_pattern() { + // Test pattern: projection -> aggregation -> projection + // Query: SELECT HEX(SUM(age + 2)) FROM users + let mut circuit = compile_sql!("SELECT HEX(SUM(age + 2)) FROM users"); + + // Initial input data + let mut input_delta = Delta::new(); + input_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".to_string().into()), + Value::Integer(25), + ], + ); + input_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".to_string().into()), + Value::Integer(30), + ], + ); + input_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".to_string().into()), + Value::Integer(35), + ], + ); + + let mut input_data = HashMap::new(); + input_data.insert("users".to_string(), input_delta); + + // Initialize the circuit with the initial data + let result = circuit.initialize(input_data).unwrap(); + + // Expected: SUM(age + 2) = (25+2) + (30+2) + (35+2) = 27 + 32 + 37 = 96 + // HEX(96) should be the hex representation of the string "96" = "3936" + assert_eq!(result.changes.len(), 1); + let (row, _weight) = &result.changes[0]; + assert_eq!(row.values.len(), 1); + + // The hex function converts the number to string first, then to hex + // 96 as string is "96", which in hex is "3936" (hex of ASCII '9' and '6') + assert_eq!( + row.values[0], + Value::Text("3936".to_string().into()), + "HEX(SUM(age + 2)) should return '3936' for sum of 96" + ); + + // Test incremental update: add a new user + let mut input_delta = Delta::new(); + input_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("David".to_string().into()), + Value::Integer(40), + ], + ); + + let mut input_data = HashMap::new(); + input_data.insert("users".to_string(), input_delta); + + let result = circuit.execute(input_data, DeltaSet::empty()).unwrap(); + + // Expected: new SUM(age + 2) = 96 + (40+2) = 138 + // HEX(138) = hex of "138" = "313338" + assert_eq!(result.changes.len(), 2); + + // First change: remove old aggregate (96) + let (row, weight) = &result.changes[0]; + assert_eq!(*weight, -1); + assert_eq!(row.values[0], Value::Text("3936".to_string().into())); + + // Second change: add new aggregate (138) + let (row, weight) = &result.changes[1]; + assert_eq!(*weight, 1); + assert_eq!( + row.values[0], + Value::Text("313338".to_string().into()), + "HEX(SUM(age + 2)) should return '313338' for sum of 138" + ); + } + + #[test] + fn test_nested_projection_with_groupby() { + // Test pattern: projection -> aggregation with GROUP BY -> projection + // Query: SELECT name, HEX(SUM(age * 2)) FROM users GROUP BY name + let mut circuit = compile_sql!("SELECT name, HEX(SUM(age * 2)) FROM users GROUP BY name"); + + // Initial input data + let mut input_delta = Delta::new(); + input_delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".to_string().into()), + Value::Integer(25), + ], + ); + input_delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".to_string().into()), + Value::Integer(30), + ], + ); + input_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Alice".to_string().into()), + Value::Integer(35), + ], + ); + + let mut input_data = HashMap::new(); + input_data.insert("users".to_string(), input_delta); + + // Initialize circuit with initial data + let result = circuit.initialize(input_data).unwrap(); + + // Expected results: + // Alice: SUM(25*2 + 35*2) = 50 + 70 = 120, HEX("120") = "313230" + // Bob: SUM(30*2) = 60, HEX("60") = "3630" + assert_eq!(result.changes.len(), 2); + + let 
results: HashMap<String, String> = result
+ .changes
+ .iter()
+ .map(|(row, _weight)| {
+ let name = match &row.values[0] {
+ Value::Text(t) => t.to_string(),
+ _ => panic!("Expected text for name"),
+ };
+ let hex_sum = match &row.values[1] {
+ Value::Text(t) => t.to_string(),
+ _ => panic!("Expected text for hex value"),
+ };
+ (name, hex_sum)
+ })
+ .collect();
+
+ assert_eq!(
+ results.get("Alice").unwrap(),
+ "313230",
+ "Alice's HEX(SUM(age * 2)) should be '313230' (120)"
+ );
+ assert_eq!(
+ results.get("Bob").unwrap(),
+ "3630",
+ "Bob's HEX(SUM(age * 2)) should be '3630' (60)"
+ );
+ }
+
+ #[test]
+ fn test_transaction_context() {
+ // Test that uncommitted changes are visible within a transaction
+ // but don't affect the operator's internal state
+ let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+ // Initialize with some data
+ let mut init_data = HashMap::new();
+ let mut delta = Delta::new();
+ delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(17),
+ ],
+ );
+ init_data.insert("users".to_string(), delta);
+
+ circuit.initialize(init_data).unwrap();
+
+ // Verify initial state: only Alice (age > 18)
+ let state = get_current_state(&circuit).unwrap();
+ assert_eq!(state.changes.len(), 1);
+ assert_eq!(state.changes[0].0.values[1], Value::Text("Alice".into()));
+
+ // Create uncommitted changes that would be visible in a transaction
+ let mut uncommitted = HashMap::new();
+ let mut uncommitted_delta = Delta::new();
+ // Add Charlie (age 30) - should be visible in transaction
+ uncommitted_delta.insert(
+ 3,
+ vec![
+ Value::Integer(3),
+ Value::Text("Charlie".into()),
+ Value::Integer(30),
+ ],
+ );
+ // Add David (age 15) - should NOT be visible (filtered out)
+ uncommitted_delta.insert(
+ 4,
+ vec![
+ Value::Integer(4),
+ Value::Text("David".into()),
+ Value::Integer(15),
+ ],
+ );
+ uncommitted.insert("users".to_string(), uncommitted_delta);
+
+ // Execute with uncommitted data - this simulates processing the uncommitted changes
+ // through the circuit to see what would be visible
+ let tx_result = circuit
+ .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+ .unwrap();
+
+ // The result should show Charlie being added (passes filter, age > 18)
+ // David is filtered out (age 15 < 18)
+ assert_eq!(tx_result.changes.len(), 1, "Should see Charlie added");
+ assert_eq!(
+ tx_result.changes[0].0.values[1],
+ Value::Text("Charlie".into())
+ );
+
+ // Now actually commit Charlie (without uncommitted context)
+ let mut commit_data = HashMap::new();
+ let mut commit_delta = Delta::new();
+ commit_delta.insert(
+ 3,
+ vec![
+ Value::Integer(3),
+ Value::Text("Charlie".into()),
+ Value::Integer(30),
+ ],
+ );
+ commit_data.insert("users".to_string(), commit_delta);
+
+ let commit_result = circuit
+ .execute(commit_data.clone(), DeltaSet::empty())
+ .unwrap();
+
+ // The commit result should show Charlie being added
+ assert_eq!(commit_result.changes.len(), 1, "Should see Charlie added");
+ assert_eq!(
+ commit_result.changes[0].0.values[1],
+ Value::Text("Charlie".into())
+ );
+
+ // Commit the change to make it permanent
+ circuit.commit(commit_data).unwrap();
+
+ // Now if we execute again with no changes, we should see no delta
+ let empty_result = circuit.execute(HashMap::new(), DeltaSet::empty()).unwrap();
+ assert_eq!(empty_result.changes.len(), 0, "No changes when no new data");
+ }
+
+ #[test]
+ fn test_uncommitted_delete() {
+ // Test that uncommitted deletes are handled correctly without affecting operator state
+ let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+ // Initialize with some data
+ let mut init_data = HashMap::new();
+ let mut delta = Delta::new();
+ delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(30),
+ ],
+ );
+ delta.insert(
+ 3,
+ vec![
+ Value::Integer(3),
+ Value::Text("Charlie".into()),
+ Value::Integer(20),
+ ],
+ );
+ init_data.insert("users".to_string(), delta);
+
+ circuit.initialize(init_data).unwrap();
+
+ // Verify initial state: Alice, Bob, Charlie (all age > 18)
+ let state = get_current_state(&circuit).unwrap();
+ assert_eq!(state.changes.len(), 3);
+
+ // Create uncommitted delete for Bob
+ let mut uncommitted = HashMap::new();
+ let mut uncommitted_delta = Delta::new();
+ uncommitted_delta.delete(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(30),
+ ],
+ );
+ uncommitted.insert("users".to_string(), uncommitted_delta);
+
+ // Execute with uncommitted delete
+ let tx_result = circuit
+ .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+ .unwrap();
+
+ // Result should show the deleted row that passed the filter
+ assert_eq!(
+ tx_result.changes.len(),
+ 1,
+ "Should see the uncommitted delete"
+ );
+
+ // Verify operator's internal state is unchanged (still has all 3 users)
+ let state_after = get_current_state(&circuit).unwrap();
+ assert_eq!(
+ state_after.changes.len(),
+ 3,
+ "Internal state should still have all 3 users"
+ );
+
+ // Now actually commit the delete
+ let mut commit_data = HashMap::new();
+ let mut commit_delta = Delta::new();
+ commit_delta.delete(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(30),
+ ],
+ );
+ commit_data.insert("users".to_string(), commit_delta);
+
+ let commit_result = circuit
+ .execute(commit_data.clone(), DeltaSet::empty())
+ .unwrap();
+
+ // Actually commit the delete to update operator state
+ circuit.commit(commit_data).unwrap();
+
+ // The commit result should show Bob being deleted
+ assert_eq!(commit_result.changes.len(), 1, "Should see Bob deleted");
+ assert_eq!(
+ commit_result.changes[0].1, -1,
+ "Delete should have weight -1"
+ );
+ assert_eq!(
+ commit_result.changes[0].0.values[1],
+ Value::Text("Bob".into())
+ );
+
+ // After commit, internal state should have only Alice and Charlie
+ let final_state = get_current_state(&circuit).unwrap();
+ assert_eq!(
+ final_state.changes.len(),
+ 2,
+ "After commit, should have Alice and Charlie"
+ );
+
+ let names: Vec<String> = final_state
+ .changes
+ .iter()
+ .map(|(row, _)| {
+ if let Value::Text(name) = &row.values[1] {
+ name.to_string()
+ } else {
+ panic!("Expected text value");
+ }
+ })
+ .collect();
+ assert!(names.contains(&"Alice".to_string()));
+ assert!(names.contains(&"Charlie".to_string()));
+ assert!(!names.contains(&"Bob".to_string()));
+ }
+
+ #[test]
+ fn test_uncommitted_update() {
+ // Test that uncommitted updates (delete + insert) are handled correctly
+ let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+ // Initialize with some data
+ let mut init_data = HashMap::new();
+ let mut delta = Delta::new();
+ delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(17),
+ ],
+ ); // Bob is 17, filtered out
+ init_data.insert("users".to_string(), delta);
+
+ circuit.initialize(init_data).unwrap();
+
+ // Create uncommitted update: Bob turns 19 (update from 17 to 19)
+ // This is modeled as delete + insert
+ let mut uncommitted = HashMap::new();
+ let mut uncommitted_delta = Delta::new();
+ uncommitted_delta.delete(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(17),
+ ],
+ );
+ uncommitted_delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(19),
+ ],
+ );
+ uncommitted.insert("users".to_string(), uncommitted_delta);
+
+ // Execute with uncommitted update
+ let tx_result = circuit
+ .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+ .unwrap();
+
+ // Bob should now appear in the result (age 19 > 18)
+ // Consolidate to see the final state
+ let mut final_result = tx_result;
+ final_result.consolidate();
+
+ assert_eq!(final_result.changes.len(), 1, "Bob should now be in view");
+ assert_eq!(
+ final_result.changes[0].0.values[1],
+ Value::Text("Bob".into())
+ );
+ assert_eq!(final_result.changes[0].0.values[2], Value::Integer(19));
+
+ // Now actually commit the update
+ let mut commit_data = HashMap::new();
+ let mut commit_delta = Delta::new();
+ commit_delta.delete(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(17),
+ ],
+ );
+ commit_delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(19),
+ ],
+ );
+ commit_data.insert("users".to_string(), commit_delta);
+
+ // Commit the update
+ circuit.commit(commit_data).unwrap();
+
+ // After committing, Bob should be in the view's state
+ let state = get_current_state(&circuit).unwrap();
+ let mut consolidated_state = state;
+ consolidated_state.consolidate();
+
+ // Should have both Alice and Bob now
+ assert_eq!(
+ consolidated_state.changes.len(),
+ 2,
+ "Should have Alice and Bob"
+ );
+
+ let names: Vec<String> = consolidated_state
+ .changes
+ .iter()
+ .map(|(row, _)| {
+ if let Value::Text(name) = &row.values[1] {
+ name.as_str().to_string()
+ } else {
+ panic!("Expected text value");
+ }
+ })
+ .collect();
+ assert!(names.contains(&"Alice".to_string()));
+ assert!(names.contains(&"Bob".to_string()));
+ }
+
+ #[test]
+ fn test_uncommitted_filtered_delete() {
+ // Test deleting a row that doesn't pass the filter
+ let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+ // Initialize with mixed data
+ let mut init_data = HashMap::new();
+ let mut delta = Delta::new();
+ delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(15),
+ ],
+ ); // Bob doesn't pass filter
+ init_data.insert("users".to_string(), delta);
+
+ circuit.initialize(init_data).unwrap();
+
+ // Create uncommitted delete for Bob (who isn't in the view because age=15)
+ let mut uncommitted = HashMap::new();
+ let mut uncommitted_delta = Delta::new();
+ uncommitted_delta.delete(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(15),
+ ],
+ );
+ uncommitted.insert("users".to_string(), uncommitted_delta);
+
+ // Execute with uncommitted delete - should produce no output changes
+ let tx_result = circuit
+ .execute(HashMap::new(), delta_set_from_map(uncommitted))
+ .unwrap();
+
+ // Bob wasn't in the view, so deleting him produces no output
+ assert_eq!(
+ tx_result.changes.len(),
+ 0,
+ "Deleting filtered row produces no changes"
+ );
+
+ // The view state should still only have Alice
+ let state = get_current_state(&circuit).unwrap();
+ assert_eq!(state.changes.len(), 1, "View still has only Alice");
+ assert_eq!(state.changes[0].0.values[1], Value::Text("Alice".into()));
+ }
+
+ #[test]
+ fn test_uncommitted_mixed_operations() {
+ // Test multiple uncommitted operations together
+ let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18");
+
+ // Initialize with some data
+ let mut init_data = HashMap::new();
+ let mut delta = Delta::new();
+ delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(30),
+ ],
+ );
+ init_data.insert("users".to_string(), delta);
+
+ circuit.initialize(init_data).unwrap();
+
+ // Verify initial state
+ let state = get_current_state(&circuit).unwrap();
+ assert_eq!(state.changes.len(), 2);
+
+ // Create uncommitted changes:
+ // - Delete Alice
+ // - Update Bob's age to 35
+ // - Insert Charlie (age 40)
+ // - Insert David (age 16, filtered out)
+ let mut uncommitted = HashMap::new();
+ let mut uncommitted_delta = Delta::new();
+ // Delete Alice
+ uncommitted_delta.delete(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ // Update Bob (delete + insert)
+ uncommitted_delta.delete(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(30),
+ ],
+ );
+ uncommitted_delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(35),
+ ],
+ );
+ // Insert Charlie
+ uncommitted_delta.insert(
+ 3,
+ vec![
+ Value::Integer(3),
+ Value::Text("Charlie".into()),
+ Value::Integer(40),
+ ],
+ );
+ // Insert David (will be filtered)
+ uncommitted_delta.insert(
+ 4,
+ vec![
+ Value::Integer(4),
+ Value::Text("David".into()),
+ Value::Integer(16),
+ ],
+ );
+ uncommitted.insert("users".to_string(), uncommitted_delta);
+
+ // Execute with uncommitted changes
+ let tx_result = circuit
+ .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+ .unwrap();
+
+ // Result should show all changes: delete Alice, update Bob (delete + insert),
+ // and insert Charlie; David (age 16) is filtered out
+ assert_eq!(
+ tx_result.changes.len(),
+ 4,
+ "Should see all uncommitted mixed operations"
+ );
+
+ // Verify operator's internal state is unchanged
+ let state_after = get_current_state(&circuit).unwrap();
+ assert_eq!(state_after.changes.len(), 2, "Still has Alice and Bob");
+
+ // Commit all changes
+ let mut commit_data = HashMap::new();
+ let mut commit_delta = Delta::new();
+ commit_delta.delete(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ commit_delta.delete(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(30),
+ ],
+ );
+ commit_delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(35),
+ ],
+ );
+ commit_delta.insert(
+ 3,
+ vec![
+ Value::Integer(3),
+ Value::Text("Charlie".into()),
+ Value::Integer(40),
+ ],
+ );
+ commit_delta.insert(
+ 4,
+ vec![
+ Value::Integer(4),
+ Value::Text("David".into()),
+ Value::Integer(16),
+ ],
+ );
+ commit_data.insert("users".to_string(), commit_delta);
+
+ let commit_result = circuit
+ .execute(commit_data.clone(), DeltaSet::empty())
+ .unwrap();
+
+ // Should see: Alice deleted, Bob deleted, Bob inserted, Charlie inserted
+ // (David filtered out)
+ assert_eq!(commit_result.changes.len(), 4, "Should see 4 changes");
+
+ // Actually commit the changes to update operator state
+ circuit.commit(commit_data).unwrap();
+
+ // After all commits, execute with no changes should return empty delta
+ let empty_result = circuit.execute(HashMap::new(), DeltaSet::empty()).unwrap();
+ assert_eq!(empty_result.changes.len(), 0, "No changes when no new data");
+ }
+
+ #[test]
+ fn test_uncommitted_aggregation() {
+ // Test that aggregations work correctly with uncommitted changes
+ // This tests the specific scenario where a transaction adds new data
+ // and we need to see correct aggregation results within the transaction
+
+ // Create a sales table schema for testing
+ let mut schema = Schema::new(false);
+ let sales_table = BTreeTable {
+ name: "sales".to_string(),
+ root_page: 2,
+ primary_key_columns: vec![],
+ columns: vec![
+ SchemaColumn {
+ name: Some("product_id".to_string()),
+ ty: Type::Integer,
+ ty_str: "INTEGER".to_string(),
+ primary_key: false,
+ is_rowid_alias: false,
+ notnull: false,
+ default: None,
+ unique: false,
+ collation: None,
+ hidden: false,
+ },
+ SchemaColumn {
+ name: Some("amount".to_string()),
+ ty: Type::Integer,
+ ty_str: "INTEGER".to_string(),
+ primary_key: false,
+ is_rowid_alias: false,
+ notnull: false,
+ default: None,
+ unique: false,
+ collation: None,
+ hidden: false,
+ },
+ ],
+ has_rowid: true,
+ is_strict: false,
+ unique_sets: None,
+ };
+ schema.add_btree_table(Arc::new(sales_table));
+
+ // Parse and compile the aggregation query
+ let sql = "SELECT product_id, SUM(amount) as total, COUNT(*) as cnt FROM sales GROUP BY product_id";
+ let mut parser = Parser::new(sql.as_bytes());
+ let cmd = parser.next().unwrap().unwrap();
+
+ let mut circuit = match cmd {
+ ast::Cmd::Stmt(stmt) => {
+ let mut builder = LogicalPlanBuilder::new(&schema);
+ let logical_plan = builder.build_statement(&stmt).unwrap();
+ DbspCompiler::new().compile(&logical_plan).unwrap()
+ }
+ _ => panic!("Expected SQL statement"),
+ };
+
+ // Initialize with base data: (1, 100), (1, 200), (2, 150), (2, 250)
+ let mut init_data = HashMap::new();
+ let mut delta = Delta::new();
+ delta.insert(1, vec![Value::Integer(1), Value::Integer(100)]);
+ delta.insert(2, vec![Value::Integer(1), Value::Integer(200)]);
+ delta.insert(3, vec![Value::Integer(2), Value::Integer(150)]);
+ delta.insert(4, vec![Value::Integer(2), Value::Integer(250)]);
+ init_data.insert("sales".to_string(), delta);
+
+ circuit.initialize(init_data).unwrap();
+
+ // Verify initial state: product 1 total=300, product 2 total=400
+ let state = get_current_state(&circuit).unwrap();
+ assert_eq!(state.changes.len(), 2, "Should have 2 product groups");
+
+ // Build a map of product_id -> (total, count)
+ let initial_results: HashMap<i64, (i64, i64)> = state
+ .changes
+ .iter()
+ .map(|(row, _)| {
+ // SUM might return Integer or Float, COUNT returns Integer
+ let product_id = match &row.values[0] {
+ Value::Integer(id) => *id,
+ _ => panic!("Product ID should be Integer, got {:?}", row.values[0]),
+ };
+
+ let total = match &row.values[1] {
+ Value::Integer(t) => *t,
+ Value::Float(t) => *t as i64,
+ _ => panic!("Total should be numeric, got {:?}", row.values[1]),
+ };
+
+ let count = match &row.values[2] {
+ Value::Integer(c) => *c,
+ _ => panic!("Count should be Integer, got {:?}", row.values[2]),
+ };
+
+ (product_id, (total, count))
+ })
+ .collect();
+
+ assert_eq!(
+ initial_results.get(&1).unwrap(),
+ &(300, 2),
+ "Product 1 should have total=300, count=2"
total=300, count=2" + ); + assert_eq!( + initial_results.get(&2).unwrap(), + &(400, 2), + "Product 2 should have total=400, count=2" + ); + + // Create uncommitted changes: INSERT (1, 50), (3, 300) + let mut uncommitted = HashMap::new(); + let mut uncommitted_delta = Delta::new(); + uncommitted_delta.insert(5, vec![Value::Integer(1), Value::Integer(50)]); // Add to product 1 + uncommitted_delta.insert(6, vec![Value::Integer(3), Value::Integer(300)]); // New product 3 + uncommitted.insert("sales".to_string(), uncommitted_delta); + + // Execute with uncommitted data - simulating a read within transaction + let tx_result = circuit + .execute(HashMap::new(), delta_set_from_map(uncommitted.clone())) + .unwrap(); + + // Result should show the aggregate changes from uncommitted data + // Product 1: retraction of (300, 2) and insertion of (350, 3) + // Product 3: insertion of (300, 1) - new product + assert_eq!( + tx_result.changes.len(), + 3, + "Should see aggregate changes from uncommitted data" + ); + + // IMPORTANT: Verify operator's internal state is unchanged + let state_after = get_current_state(&circuit).unwrap(); + assert_eq!( + state_after.changes.len(), + 2, + "Internal state should still have 2 groups" + ); + + // Verify the internal state still has original values + let state_results: HashMap = state_after + .changes + .iter() + .map(|(row, _)| { + let product_id = match &row.values[0] { + Value::Integer(id) => *id, + _ => panic!("Product ID should be Integer"), + }; + + let total = match &row.values[1] { + Value::Integer(t) => *t, + Value::Float(t) => *t as i64, + _ => panic!("Total should be numeric"), + }; + + let count = match &row.values[2] { + Value::Integer(c) => *c, + _ => panic!("Count should be Integer"), + }; + + (product_id, (total, count)) + }) + .collect(); + + assert_eq!( + state_results.get(&1).unwrap(), + &(300, 2), + "Product 1 unchanged" + ); + assert_eq!( + state_results.get(&2).unwrap(), + &(400, 2), + "Product 2 unchanged" + ); + assert!( + !state_results.contains_key(&3), + "Product 3 should not be in committed state" + ); + + // Now actually commit the changes + let mut commit_data = HashMap::new(); + let mut commit_delta = Delta::new(); + commit_delta.insert(5, vec![Value::Integer(1), Value::Integer(50)]); + commit_delta.insert(6, vec![Value::Integer(3), Value::Integer(300)]); + commit_data.insert("sales".to_string(), commit_delta); + + let commit_result = circuit + .execute(commit_data.clone(), DeltaSet::empty()) + .unwrap(); + + // Should see changes for product 1 (updated) and product 3 (new) + assert_eq!( + commit_result.changes.len(), + 3, + "Should see 3 changes (delete old product 1, insert new product 1, insert product 3)" + ); + + // Actually commit the changes to update operator state + circuit.commit(commit_data).unwrap(); + + // After commit, verify final state + let final_state = get_current_state(&circuit).unwrap(); + assert_eq!( + final_state.changes.len(), + 3, + "Should have 3 product groups after commit" + ); + + let final_results: HashMap = final_state + .changes + .iter() + .map(|(row, _)| { + let product_id = match &row.values[0] { + Value::Integer(id) => *id, + _ => panic!("Product ID should be Integer"), + }; + + let total = match &row.values[1] { + Value::Integer(t) => *t, + Value::Float(t) => *t as i64, + _ => panic!("Total should be numeric"), + }; + + let count = match &row.values[2] { + Value::Integer(c) => *c, + _ => panic!("Count should be Integer"), + }; + + (product_id, (total, count)) + }) + .collect(); + + assert_eq!( + 
final_results.get(&1).unwrap(), + &(350, 3), + "Product 1 should have total=350, count=3" + ); + assert_eq!( + final_results.get(&2).unwrap(), + &(400, 2), + "Product 2 should have total=400, count=2" + ); + assert_eq!( + final_results.get(&3).unwrap(), + &(300, 1), + "Product 3 should have total=300, count=1" + ); + } + + #[test] + fn test_uncommitted_data_visible_in_transaction() { + // Test that uncommitted INSERTs are visible within the same transaction + // This simulates: BEGIN; INSERT ...; SELECT * FROM view; COMMIT; + + let mut circuit = compile_sql!("SELECT * FROM users WHERE age > 18"); + + // Initialize with some data - need to match the schema (id, name, age) + let mut init_data = HashMap::new(); + let mut delta = Delta::new(); + delta.insert( + 1, + vec![ + Value::Integer(1), + Value::Text("Alice".into()), + Value::Integer(25), + ], + ); + delta.insert( + 2, + vec![ + Value::Integer(2), + Value::Text("Bob".into()), + Value::Integer(30), + ], + ); + init_data.insert("users".to_string(), delta); + + circuit.initialize(init_data.clone()).unwrap(); + + // Verify initial state + let state = get_current_state(&circuit).unwrap(); + assert_eq!( + state.len(), + 2, + "Should have 2 users initially (both pass age > 18 filter)" + ); + + // Simulate a transaction: INSERT new users that pass the filter - match schema (id, name, age) + let mut uncommitted = HashMap::new(); + let mut tx_delta = Delta::new(); + tx_delta.insert( + 3, + vec![ + Value::Integer(3), + Value::Text("Charlie".into()), + Value::Integer(35), + ], + ); + tx_delta.insert( + 4, + vec![ + Value::Integer(4), + Value::Text("David".into()), + Value::Integer(20), + ], + ); + uncommitted.insert("users".to_string(), tx_delta); + + // Execute with uncommitted data - this should return the uncommitted changes + // that passed through the filter (age > 18) + let tx_result = circuit + .execute(HashMap::new(), delta_set_from_map(uncommitted.clone())) + .unwrap(); + + // IMPORTANT: tx_result should contain the filtered uncommitted changes! 
+ // Both Charlie (35) and David (20) should pass the age > 18 filter
+ assert_eq!(
+ tx_result.len(),
+ 2,
+ "Should see 2 uncommitted rows that pass filter"
+ );
+
+ // Verify the uncommitted results contain the expected rows
+ let has_charlie = tx_result.changes.iter().any(|(row, _)| row.rowid == 3);
+ assert!(
+ has_charlie,
+ "Should find Charlie (rowid=3) in uncommitted results"
+ );
+
+ let has_david = tx_result.changes.iter().any(|(row, _)| row.rowid == 4);
+ assert!(
+ has_david,
+ "Should find David (rowid=4) in uncommitted results"
+ );
+
+ // CRITICAL: Verify the operator state wasn't modified by uncommitted execution
+ let state_after_uncommitted = get_current_state(&circuit).unwrap();
+ assert_eq!(
+ state_after_uncommitted.len(),
+ 2,
+ "State should STILL be 2 after uncommitted execution - only Alice and Bob"
+ );
+
+ // The state should not contain Charlie or David
+ let has_charlie_in_state = state_after_uncommitted
+ .changes
+ .iter()
+ .any(|(row, _)| row.rowid == 3);
+ let has_david_in_state = state_after_uncommitted
+ .changes
+ .iter()
+ .any(|(row, _)| row.rowid == 4);
+ assert!(
+ !has_charlie_in_state,
+ "Charlie should NOT be in operator state (uncommitted)"
+ );
+ assert!(
+ !has_david_in_state,
+ "David should NOT be in operator state (uncommitted)"
+ );
+ }
+
+ #[test]
+ fn test_uncommitted_aggregation_with_rollback() {
+ // Test that rollback properly discards uncommitted aggregation changes
+ // Similar to test_uncommitted_aggregation but explicitly tests rollback semantics
+
+ // Create a simple aggregation circuit
+ let mut circuit = compile_sql!("SELECT age, COUNT(*) as cnt FROM users GROUP BY age");
+
+ // Initialize with some data
+ let mut init_data = HashMap::new();
+ let mut delta = Delta::new();
+ delta.insert(
+ 1,
+ vec![
+ Value::Integer(1),
+ Value::Text("Alice".into()),
+ Value::Integer(25),
+ ],
+ );
+ delta.insert(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(30),
+ ],
+ );
+ delta.insert(
+ 3,
+ vec![
+ Value::Integer(3),
+ Value::Text("Charlie".into()),
+ Value::Integer(25),
+ ],
+ );
+ delta.insert(
+ 4,
+ vec![
+ Value::Integer(4),
+ Value::Text("David".into()),
+ Value::Integer(30),
+ ],
+ );
+ init_data.insert("users".to_string(), delta);
+
+ circuit.initialize(init_data).unwrap();
+
+ // Verify initial state: age 25 count=2, age 30 count=2
+ let state = get_current_state(&circuit).unwrap();
+ assert_eq!(state.changes.len(), 2);
+
+ let initial_counts: HashMap<i64, i64> = state
+ .changes
+ .iter()
+ .map(|(row, _)| {
+ if let (Value::Integer(age), Value::Integer(count)) =
+ (&row.values[0], &row.values[1])
+ {
+ (*age, *count)
+ } else {
+ panic!("Unexpected value types");
+ }
+ })
+ .collect();
+
+ assert_eq!(initial_counts.get(&25).unwrap(), &2);
+ assert_eq!(initial_counts.get(&30).unwrap(), &2);
+
+ // Create uncommitted changes that would affect aggregations
+ let mut uncommitted = HashMap::new();
+ let mut uncommitted_delta = Delta::new();
+ // Add more people aged 25
+ uncommitted_delta.insert(
+ 5,
+ vec![
+ Value::Integer(5),
+ Value::Text("Eve".into()),
+ Value::Integer(25),
+ ],
+ );
+ uncommitted_delta.insert(
+ 6,
+ vec![
+ Value::Integer(6),
+ Value::Text("Frank".into()),
+ Value::Integer(25),
+ ],
+ );
+ // Add person aged 35 (new group)
+ uncommitted_delta.insert(
+ 7,
+ vec![
+ Value::Integer(7),
+ Value::Text("Grace".into()),
+ Value::Integer(35),
+ ],
+ );
+ // Delete Bob (age 30)
+ uncommitted_delta.delete(
+ 2,
+ vec![
+ Value::Integer(2),
+ Value::Text("Bob".into()),
+ Value::Integer(30),
+ ],
+ );
+ uncommitted.insert("users".to_string(), uncommitted_delta);
+
+ // Execute with uncommitted changes
+ let tx_result = circuit
+ .execute(HashMap::new(), delta_set_from_map(uncommitted.clone()))
+ .unwrap();
+
+ // Should see the aggregate changes from uncommitted data
+ // Age 25: retraction of count 2 and insertion of count 4 (Eve and Frank added)
+ // Age 30: retraction of count 2 and insertion of count 1 (Bob deleted)
+ // Age 35: insertion of count 1 (Grace creates a new group)
+ assert!(
+ !tx_result.changes.is_empty(),
+ "Should see aggregate changes from uncommitted data"
+ );
+
+ // Verify internal state is unchanged (simulating rollback by not committing)
+ let state_after_rollback = get_current_state(&circuit).unwrap();
+ assert_eq!(
+ state_after_rollback.changes.len(),
+ 2,
+ "Should still have 2 age groups"
+ );
+
+ let rollback_counts: HashMap<i64, i64> = state_after_rollback
+ .changes
+ .iter()
+ .map(|(row, _)| {
+ if let (Value::Integer(age), Value::Integer(count)) =
+ (&row.values[0], &row.values[1])
+ {
+ (*age, *count)
+ } else {
+ panic!("Unexpected value types");
+ }
+ })
+ .collect();
+
+ // Verify counts are unchanged after rollback
+ assert_eq!(
+ rollback_counts.get(&25).unwrap(),
+ &2,
+ "Age 25 count unchanged"
+ );
+ assert_eq!(
+ rollback_counts.get(&30).unwrap(),
+ &2,
+ "Age 30 count unchanged"
+ );
+ assert!(
+ !rollback_counts.contains_key(&35),
+ "Age 35 should not exist"
+ );
+ }
+
+ #[test]
+ fn test_circuit_rowid_update_consolidation() {
+ // Test that circuit properly consolidates state when rowid changes
+ let mut circuit = DbspCircuit::new();
+
+ // Create a simple filter node
+ let schema = Arc::new(LogicalSchema::new(vec![
+ ("id".to_string(), Type::Integer),
+ ("value".to_string(), Type::Integer),
+ ]));
+
+ // First create an input node
+ let input_id = circuit.add_node(
+ DbspOperator::Input {
+ name: "test".to_string(),
+ schema: schema.clone(),
+ },
+ vec![],
+ None, // Input nodes don't have executables
+ );
+
+ let filter_op = FilterOperator::new(
+ FilterPredicate::GreaterThan {
+ column: "value".to_string(),
+ value: Value::Integer(10),
+ },
+ vec!["id".to_string(), "value".to_string()],
+ );
+
+ // Create the filter predicate using DbspExpr
+ let predicate = DbspExpr::BinaryExpr {
+ left: Box::new(DbspExpr::Column("value".to_string())),
+ op: ast::Operator::Greater,
+ right: Box::new(DbspExpr::Literal(Value::Integer(10))),
+ };
+
+ let filter_id = circuit.add_node(
+ DbspOperator::Filter { predicate },
+ vec![input_id], // Filter takes input from the input node
+ Some(Box::new(filter_op)),
+ );
+
+ circuit.root = Some(filter_id);
+
+ // Initialize with a row
+ let mut init_data = HashMap::new();
+ let mut delta = Delta::new();
+ delta.insert(5, vec![Value::Integer(5), Value::Integer(20)]);
+ init_data.insert("test".to_string(), delta);
+
+ circuit.initialize(init_data).unwrap();
+
+ // Verify initial state
+ let state = get_current_state(&circuit).unwrap();
+ assert_eq!(state.changes.len(), 1);
+ assert_eq!(state.changes[0].0.rowid, 5);
+
+ // Now update the rowid from 5 to 3
+ let mut update_data = HashMap::new();
+ let mut update_delta = Delta::new();
+ update_delta.delete(5, vec![Value::Integer(5), Value::Integer(20)]);
+ update_delta.insert(3, vec![Value::Integer(3), Value::Integer(20)]);
+ update_data.insert("test".to_string(), update_delta);
+
+ circuit
+ .execute(update_data.clone(), DeltaSet::empty())
+ .unwrap();
+
+ // Commit the changes to update operator state
+ circuit.commit(update_data).unwrap();
+
+ // The circuit should consolidate the state properly
+ let final_state = get_current_state(&circuit).unwrap();
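+ // Consolidation must cancel the -1/+1 pair left behind for the old rowid 5
+ // entry and keep a single entry (weight +1) under the new rowid 3.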
assert_eq!( + final_state.changes.len(), + 1, + "Circuit should consolidate to single row" + ); + assert_eq!(final_state.changes[0].0.rowid, 3); + assert_eq!( + final_state.changes[0].0.values, + vec![Value::Integer(3), Value::Integer(20)] + ); + assert_eq!(final_state.changes[0].1, 1); + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index 328f1a510..4c26b91ba 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -1,3 +1,4 @@ +pub mod compiler; pub mod dbsp; pub mod expr_compiler; pub mod hashable_row; diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 75366ed06..85f7e640c 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -359,7 +359,7 @@ pub enum JoinType { Right, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum AggregateFunction { Count, Sum(String), @@ -774,6 +774,46 @@ impl ProjectOperator { }) } + /// Create a ProjectOperator from pre-compiled expressions + pub fn from_compiled( + compiled_exprs: Vec, + aliases: Vec>, + input_column_names: Vec, + output_column_names: Vec, + ) -> crate::Result { + // Set up internal connection for expression evaluation + let io = Arc::new(crate::MemoryIO::new()); + let db = Database::open_file( + io, ":memory:", false, // no MVCC needed for expression evaluation + false, // no indexes needed + )?; + let internal_conn = db.connect()?; + // Set to read-only mode and disable auto-commit since we're only evaluating expressions + internal_conn.query_only.set(true); + internal_conn.auto_commit.set(false); + + // Create ProjectColumn structs from compiled expressions + let columns: Vec = compiled_exprs + .into_iter() + .zip(aliases) + .map(|(compiled, alias)| ProjectColumn { + // Create a placeholder AST expression since we already have the compiled version + expr: turso_parser::ast::Expr::Literal(turso_parser::ast::Literal::Null), + alias, + compiled, + }) + .collect(); + + Ok(Self { + columns, + input_column_names, + output_column_names, + current_state: Delta::new(), + tracker: None, + internal_conn, + }) + } + /// Get the columns for this projection pub fn columns(&self) -> &[ProjectColumn] { &self.columns diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 2de196a8f..9f54c96c0 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -21,7 +21,6 @@ pub(crate) mod group_by; pub(crate) mod index; pub(crate) mod insert; pub(crate) mod integrity_check; -#[cfg(test)] pub(crate) mod logical; pub(crate) mod main_loop; pub(crate) mod optimizer; From 565c2a698af7a99fa39b7235a63ca23cff651efe Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Wed, 27 Aug 2025 10:38:11 -0500 Subject: [PATCH 8/9] adjust views to use circuits --- core/incremental/view.rs | 861 ++++++--------------------------------- 1 file changed, 128 insertions(+), 733 deletions(-) diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 7033fe83c..bcae0afae 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -1,13 +1,12 @@ +use super::compiler::{DbspCircuit, DbspCompiler, DeltaSet}; use super::dbsp::{RowKeyStream, RowKeyZSet}; -use super::operator::{ - AggregateFunction, AggregateOperator, ComputationTracker, Delta, FilterOperator, - FilterPredicate, IncrementalOperator, ProjectOperator, -}; +use super::operator::{ComputationTracker, Delta, FilterPredicate}; use crate::schema::{BTreeTable, Column, Schema}; +use crate::translate::logical::LogicalPlanBuilder; use crate::types::{IOCompletions, IOResult, Value}; -use 
crate::util::{extract_column_name_from_expr, extract_view_columns}; +use crate::util::extract_view_columns; use crate::{io_yield_one, Completion, LimboError, Result, Statement}; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::fmt; use std::sync::{Arc, Mutex}; use turso_parser::ast; @@ -60,8 +59,7 @@ pub struct ViewTransactionState { /// for large aggregations, because then we don't have to re-compute when opening the database /// again. /// -/// Right now we are supporting the simplest views by keeping the operators in the view and -/// applying them in a sane order. But the general solution would turn this into a DBSP circuit. +/// Uses DBSP circuits for incremental computation. #[derive(Debug)] pub struct IncrementalView { // Stream of row keys for this view @@ -75,12 +73,11 @@ pub struct IncrementalView { // The SELECT statement that defines how to transform input data pub select_stmt: ast::Select, - // Internal filter operator for predicate evaluation - filter_operator: Option, - // Internal project operator for value transformation - project_operator: Option, - // Internal aggregate operator for GROUP BY and aggregations - aggregate_operator: Option, + // DBSP circuit that encapsulates the computation + circuit: DbspCircuit, + // Track whether circuit has been initialized with data + circuit_initialized: bool, + // Tables referenced by this view (extracted from FROM clause and JOINs) base_table: Arc, // The view's output columns with their types @@ -108,6 +105,25 @@ impl IncrementalView { Ok(()) } + /// Try to compile the SELECT statement into a DBSP circuit + fn try_compile_circuit( + select: &ast::Select, + schema: &Schema, + _base_table: &Arc, + ) -> Result { + // Build the logical plan from the SELECT statement + let mut builder = LogicalPlanBuilder::new(schema); + // Convert Select to a Stmt for the builder + let stmt = ast::Stmt::Select(select.clone()); + let logical_plan = builder.build_statement(&stmt)?; + + // Compile the logical plan to a DBSP circuit + let compiler = DbspCompiler::new(); + let circuit = compiler.compile(&logical_plan)?; + + Ok(circuit) + } + /// Get an iterator over column names, using enumerated naming for unnamed columns pub fn column_names(&self) -> impl Iterator + '_ { self.columns.iter().enumerate().map(|(i, col)| { @@ -136,14 +152,6 @@ impl IncrementalView { false } - /// Apply filter operator to check if values pass the view's WHERE clause - fn apply_filter(&self, values: &[Value]) -> bool { - if let Some(ref filter_op) = self.filter_operator { - filter_op.evaluate_predicate(values) - } else { - true - } - } pub fn from_sql(sql: &str, schema: &Schema) -> Result { let mut parser = Parser::new(sql.as_bytes()); let cmd = parser.next_cmd()?; @@ -173,10 +181,6 @@ impl IncrementalView { // Extract output columns using the shared function let view_columns = extract_view_columns(&select, schema); - // Extract GROUP BY columns and aggregate functions - let (group_by_columns, aggregate_functions, _old_output_names) = - Self::extract_aggregation_info(&select); - let (join_tables, join_condition) = Self::extract_join_info(&select); if join_tables.is_some() || join_condition.is_some() { return Err(LimboError::ParseError( @@ -199,105 +203,43 @@ impl IncrementalView { )); }; - let base_table_column_names = base_table - .columns - .iter() - .enumerate() - .map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}"))) - .collect(); - Self::new( name, - Vec::new(), // Empty initial data where_predicate, select.clone(), 
base_table, - base_table_column_names, view_columns, - group_by_columns, - aggregate_functions, schema, ) } - #[allow(clippy::too_many_arguments)] pub fn new( name: String, - initial_data: Vec<(i64, Vec)>, where_predicate: FilterPredicate, select_stmt: ast::Select, base_table: Arc, - base_table_column_names: Vec, columns: Vec, - group_by_columns: Vec, - aggregate_functions: Vec, schema: &Schema, ) -> Result { - let mut records = BTreeMap::new(); - - for (row_key, values) in initial_data { - records.insert(row_key, values); - } - - // Create initial stream with row keys - let mut zset = RowKeyZSet::new(); - for (row_key, values) in &records { - use crate::incremental::hashable_row::HashableRow; - let row = HashableRow::new(*row_key, values.clone()); - zset.insert(row, 1); - } + let records = BTreeMap::new(); // Create the tracker that will be shared by all operators let tracker = Arc::new(Mutex::new(ComputationTracker::new())); - // Create filter operator if we have a predicate - let filter_operator = if !matches!(where_predicate, FilterPredicate::None) { - let mut filter_op = - FilterOperator::new(where_predicate.clone(), base_table_column_names.clone()); - filter_op.set_tracker(tracker.clone()); - Some(filter_op) - } else { - None - }; + // Compile the SELECT statement into a DBSP circuit + let circuit = Self::try_compile_circuit(&select_stmt, schema, &base_table)?; - // Check if this is an aggregated view - let is_aggregated = !group_by_columns.is_empty() || !aggregate_functions.is_empty(); - - // Create aggregate operator if needed - let aggregate_operator = if is_aggregated { - let mut agg_op = AggregateOperator::new( - group_by_columns, - aggregate_functions, - base_table_column_names.clone(), - ); - agg_op.set_tracker(tracker.clone()); - Some(agg_op) - } else { - None - }; - - // Only create project operator for non-aggregated views - let project_operator = if !is_aggregated { - let mut proj_op = ProjectOperator::from_select( - &select_stmt, - base_table_column_names.clone(), - schema, - )?; - proj_op.set_tracker(tracker.clone()); - Some(proj_op) - } else { - None - }; + // Circuit will be initialized when we first call merge_delta + let circuit_initialized = false; Ok(Self { - stream: RowKeyStream::from_zset(zset), + stream: RowKeyStream::from_zset(RowKeyZSet::new()), name, records, where_predicate, select_stmt, - filter_operator, - project_operator, - aggregate_operator, + circuit, + circuit_initialized, base_table, columns, populate_state: PopulateState::Start, @@ -338,46 +280,28 @@ impl IncrementalView { // Get the base table from referenced tables let table = &self.base_table; - // Build column list for SELECT clause - let select_columns = if let Some(ref project_op) = self.project_operator { - // Get the columns used by the projection operator - let mut columns = Vec::new(); - for col in project_op.columns() { - // Check if it's a simple column reference - if let turso_parser::ast::Expr::Id(name) = &col.expr { - columns.push(name.as_str().to_string()); - } else { - // For expressions, we need all columns (for now) - columns.clear(); - columns.push("*".to_string()); - break; - } - } - if columns.is_empty() || columns.contains(&"*".to_string()) { - "*".to_string() - } else { - // Add the columns and always include rowid - columns.join(", ").to_string() - } - } else { - // No projection, use all columns + // Check if the table has a rowid alias (INTEGER PRIMARY KEY column) + let has_rowid_alias = table.columns.iter().any(|col| col.is_rowid_alias); + + // For now, select all 
columns since we don't have the static operators + // The circuit will handle filtering and projection + // If there's a rowid alias, we don't need to select rowid separately + let select_clause = if has_rowid_alias { "*".to_string() + } else { + "*, rowid".to_string() }; - // Build WHERE clause from filter operator - let where_clause = if let Some(ref filter_op) = self.filter_operator { - self.build_where_clause(filter_op.predicate())? - } else { - String::new() - }; + // Build WHERE clause from the where_predicate + let where_clause = self.build_where_clause(&self.where_predicate)?; // Construct the final query let query = if where_clause.is_empty() { - format!("SELECT {}, rowid FROM {}", select_columns, table.name) + format!("SELECT {} FROM {}", select_clause, table.name) } else { format!( - "SELECT {}, rowid FROM {} WHERE {}", - select_columns, table.name, where_clause + "SELECT {} FROM {} WHERE {}", + select_clause, table.name, where_clause ) }; Ok(query) @@ -494,20 +418,40 @@ impl IncrementalView { let all_values: Vec = row.get_values().cloned().collect(); - // The last value should be the rowid - let rowid = match all_values.last() { - Some(crate::types::Value::Integer(id)) => *id, - _ => { - // This shouldn't happen - rowid must be an integer - *rows_processed += 1; - batch_count += 1; - continue; - } + // Determine how to extract the rowid + // If there's a rowid alias (INTEGER PRIMARY KEY), the rowid is one of the columns + // Otherwise, it's the last value we explicitly selected + let (rowid, values) = if let Some((idx, _)) = + self.base_table.get_rowid_alias_column() + { + // The rowid is the value at the rowid alias column index + let rowid = match all_values.get(idx) { + Some(crate::types::Value::Integer(id)) => *id, + _ => { + // This shouldn't happen - rowid alias must be an integer + *rows_processed += 1; + batch_count += 1; + continue; + } + }; + // All values are table columns (no separate rowid was selected) + (rowid, all_values) + } else { + // The last value is the explicitly selected rowid + let rowid = match all_values.last() { + Some(crate::types::Value::Integer(id)) => *id, + _ => { + // This shouldn't happen - rowid must be an integer + *rows_processed += 1; + batch_count += 1; + continue; + } + }; + // Get all values except the rowid + let values = all_values[..all_values.len() - 1].to_vec(); + (rowid, values) }; - // Get all values except the rowid - let values = all_values[..all_values.len() - 1].to_vec(); - // Add to batch delta - let merge_delta handle filtering and aggregation batch_delta.insert(rowid, values); @@ -542,120 +486,6 @@ impl IncrementalView { } } - /// Extract GROUP BY columns and aggregate functions from SELECT statement - fn extract_aggregation_info( - select: &ast::Select, - ) -> (Vec, Vec, Vec) { - use turso_parser::ast::*; - - let mut group_by_columns = Vec::new(); - let mut aggregate_functions = Vec::new(); - let mut output_column_names = Vec::new(); - - if let OneSelect::Select { - ref group_by, - ref columns, - .. 
-        } = select.body.select
-        {
-            // Extract GROUP BY columns
-            if let Some(group_by) = group_by {
-                for expr in &group_by.exprs {
-                    if let Some(col_name) = extract_column_name_from_expr(expr) {
-                        group_by_columns.push(col_name);
-                    }
-                }
-            }
-
-            // Extract aggregate functions and column names/aliases from SELECT list
-            for result_col in columns {
-                match result_col {
-                    ResultColumn::Expr(expr, alias) => {
-                        // Extract aggregate functions
-                        let mut found_aggregates = Vec::new();
-                        Self::extract_aggregates_from_expr(expr, &mut found_aggregates);
-
-                        // Determine the output column name
-                        let col_name = if let Some(As::As(alias_name)) = alias {
-                            // Use the provided alias
-                            alias_name.as_str().to_string()
-                        } else if !found_aggregates.is_empty() {
-                            // Use the default name from the aggregate function
-                            found_aggregates[0].default_output_name()
-                        } else if let Some(name) = extract_column_name_from_expr(expr) {
-                            // Use the column name
-                            name
-                        } else {
-                            // Fallback to a generic name
-                            format!("column{}", output_column_names.len() + 1)
-                        };
-
-                        output_column_names.push(col_name);
-                        aggregate_functions.extend(found_aggregates);
-                    }
-                    ResultColumn::Star => {
-                        // For SELECT *, we'd need to know the base table columns
-                        // This is handled elsewhere
-                    }
-                    ResultColumn::TableStar(_) => {
-                        // Similar to Star, but for a specific table
-                    }
-                }
-            }
-        }
-
-        (group_by_columns, aggregate_functions, output_column_names)
-    }
-
-    /// Recursively extract aggregate functions from an expression
-    fn extract_aggregates_from_expr(
-        expr: &ast::Expr,
-        aggregate_functions: &mut Vec<AggregateFunction>,
-    ) {
-        use crate::function::Func;
-        use turso_parser::ast::*;
-
-        match expr {
-            // Handle COUNT(*) and similar aggregate functions with *
-            Expr::FunctionCallStar { name, .. } => {
-                // FunctionCallStar is typically COUNT(*), which has 0 args
-                if let Ok(func) = Func::resolve_function(name.as_str(), 0) {
-                    // Use the centralized mapping from operator.rs
-                    // For COUNT(*), we pass None as the input column
-                    if let Some(agg_func) = AggregateFunction::from_sql_function(&func, None) {
-                        aggregate_functions.push(agg_func);
-                    }
-                }
-            }
-            Expr::FunctionCall { name, args, .. } => {
-                // Regular function calls with arguments
-                let arg_count = args.len();
-
-                if let Ok(func) = Func::resolve_function(name.as_str(), arg_count) {
-                    // Extract the input column if there's an argument
-                    let input_column = if arg_count > 0 {
-                        args.first().and_then(extract_column_name_from_expr)
-                    } else {
-                        None
-                    };
-
-                    // Use the centralized mapping from operator.rs
-                    if let Some(agg_func) =
-                        AggregateFunction::from_sql_function(&func, input_column)
-                    {
-                        aggregate_functions.push(agg_func);
-                    }
-                }
-            }
-            // Recursively check binary expressions, etc.
-            Expr::Binary(left, _, right) => {
-                Self::extract_aggregates_from_expr(left, aggregate_functions);
-                Self::extract_aggregates_from_expr(right, aggregate_functions);
-            }
-            _ => {}
-        }
-    }
-
     /// Extract JOIN information from SELECT statement
     #[allow(clippy::type_complexity)]
     pub fn extract_join_info(
@@ -743,50 +573,36 @@ impl IncrementalView {
 
     /// Get current data merged with transaction state
     pub fn current_data(&self, tx_state: Option<&ViewTransactionState>) -> Vec<(i64, Vec<Value>)> {
-        // Start with committed records
         if let Some(tx_state) = tx_state {
-            // processed_delta = input delta for now. Need to apply operations
-            let processed_delta = &tx_state.delta;
+            // Use circuit to process uncommitted changes
+            let mut uncommitted = DeltaSet::new();
+            uncommitted.insert(self.base_table.name.clone(), tx_state.delta.clone());
 
-            // For non-aggregation views, merge the processed delta with committed records
-            let mut result_map: BTreeMap<i64, Vec<Value>> = self.records.clone();
-
-            for (row, weight) in &processed_delta.changes {
-                if *weight > 0 && self.apply_filter(&row.values) {
-                    result_map.insert(row.rowid, row.values.clone());
-                } else if *weight < 0 {
-                    result_map.remove(&row.rowid);
+            // Execute with uncommitted changes (won't affect circuit state)
+            match self.circuit.execute(HashMap::new(), uncommitted) {
+                Ok(processed_delta) => {
+                    // Merge processed delta with committed records
+                    let mut result_map: BTreeMap<i64, Vec<Value>> = self.records.clone();
+                    for (row, weight) in &processed_delta.changes {
+                        if *weight > 0 {
+                            result_map.insert(row.rowid, row.values.clone());
+                        } else if *weight < 0 {
+                            result_map.remove(&row.rowid);
+                        }
+                    }
+                    result_map.into_iter().collect()
+                }
+                Err(e) => {
+                    // No fallback: failing to execute the circuit here is a bug
+                    panic!("Failed to execute circuit with uncommitted data: {e:?}");
                 }
             }
-
-            result_map.into_iter().collect()
         } else {
             // No transaction state: return committed records
             self.records.clone().into_iter().collect()
         }
     }
 
-    /// Apply filter operator to a delta if present and commit the changes
-    fn apply_filter_to_delta(&mut self, delta: Delta) -> Delta {
-        if let Some(ref mut filter_op) = self.filter_operator {
-            // Commit updates state and returns output
-            filter_op.commit(delta)
-        } else {
-            delta
-        }
-    }
-
-    /// Apply aggregation operator to a delta if this is an aggregated view and commit the changes
-    fn apply_aggregation_to_delta(&mut self, delta: Delta) -> Delta {
-        if let Some(ref mut agg_op) = self.aggregate_operator {
-            // Commit updates state and returns output
-            agg_op.commit(delta)
-        } else {
-            delta
-        }
-    }
-
     /// Merge a delta of changes into the view's current state
     pub fn merge_delta(&mut self, delta: &Delta) {
         // Early return if delta is empty
@@ -794,16 +610,33 @@ impl IncrementalView {
             return;
         }
 
-        // Apply operators in pipeline
-        let mut current_delta = delta.clone();
-        current_delta = self.apply_filter_to_delta(current_delta);
+        // Use the circuit to process the delta
+        let mut input_data = HashMap::new();
+        input_data.insert(self.base_table.name.clone(), delta.clone());
 
-        // Apply projection operator if present (for non-aggregated views)
-        if let Some(ref mut project_op) = self.project_operator {
-            current_delta = project_op.commit(current_delta);
+        // If circuit hasn't been initialized yet, initialize it first
+        // This happens during populate_from_table
+        if !self.circuit_initialized {
+            // Initialize the circuit with empty state
+            self.circuit
+                .initialize(HashMap::new())
+                .expect("Failed to initialize circuit");
+            self.circuit_initialized = true;
         }
 
-        current_delta = self.apply_aggregation_to_delta(current_delta);
+        // Execute the circuit to process the delta
+        let current_delta = match self.circuit.execute(input_data.clone(), DeltaSet::empty()) {
+            Ok(output) => {
+                // Commit the changes to the circuit's internal state
+                self.circuit
+                    .commit(input_data)
+                    .expect("Failed to commit to circuit");
+                output
+            }
+            Err(e) => {
+                panic!("Failed to execute circuit: {e:?}");
+            }
+        };
 
         // Update records and stream with the processed delta
         let mut zset_delta = RowKeyZSet::new();
@@ -821,441 +654,3 @@ impl IncrementalView {
         self.stream.apply_delta(&zset_delta);
     }
 }
-
-#[cfg(test)] -mod tests { - use super::*; - use crate::incremental::operator::{Delta, IncrementalOperator}; - use crate::schema::{BTreeTable, Column, Schema, Type}; - use crate::types::Value; - use std::sync::Arc; - fn create_test_schema() -> Schema { - let mut schema = Schema::new(false); - let table = BTreeTable { - root_page: 1, - name: "t".to_string(), - columns: vec![ - Column { - name: Some("a".to_string()), - ty: Type::Integer, - ty_str: "INTEGER".to_string(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: None, - hidden: false, - }, - Column { - name: Some("b".to_string()), - ty: Type::Integer, - ty_str: "INTEGER".to_string(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: None, - hidden: false, - }, - Column { - name: Some("c".to_string()), - ty: Type::Integer, - ty_str: "INTEGER".to_string(), - primary_key: false, - is_rowid_alias: false, - notnull: false, - default: None, - unique: false, - collation: None, - hidden: false, - }, - ], - primary_key_columns: vec![], - has_rowid: true, - is_strict: false, - unique_sets: None, - }; - schema.add_btree_table(Arc::new(table)); - schema - } - - #[test] - fn test_projection_simple_columns() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, b FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(10), Value::Integer(20), Value::Integer(30)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!(output.values, vec![Value::Integer(10), Value::Integer(20)]); - } - - #[test] - fn test_projection_arithmetic_expression() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a * 2 as doubled FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(4), Value::Integer(2), Value::Integer(0)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!(output.values, vec![Value::Integer(8)]); - } - - #[test] - fn test_projection_multiple_expressions() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a + b as sum, a - b as diff, c FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(10), Value::Integer(3), Value::Integer(7)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!( - output.values, - vec![Value::Integer(13), Value::Integer(7), Value::Integer(7),] - ); - } - - #[test] - fn test_projection_function_call() { - let schema = create_test_schema(); - let 
sql = "CREATE MATERIALIZED VIEW v AS SELECT abs(a - 300) as abs_diff, b FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(255), Value::Integer(20), Value::Integer(30)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - // abs(255 - 300) = abs(-45) = 45 - assert_eq!(output.values, vec![Value::Integer(45), Value::Integer(20),]); - } - - #[test] - fn test_projection_mixed_columns_and_expressions() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, b * 2 as doubled, c, a + b + c as total FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(1), Value::Integer(5), Value::Integer(3)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!( - output.values, - vec![ - Value::Integer(1), - Value::Integer(10), - Value::Integer(3), - Value::Integer(9), - ] - ); - } - - #[test] - fn test_projection_complex_expression() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT (a * 2) + (b * 3) as weighted, c / 2 as half FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(5), Value::Integer(2), Value::Integer(10)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!(output.values, vec![Value::Integer(16), Value::Integer(5),]); - } - - #[test] - fn test_projection_with_where_clause() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, a * 2 as doubled FROM t WHERE b > 2"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - assert!(view.filter_operator.is_some()); - - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(4), Value::Integer(3), Value::Integer(0)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - assert_eq!(output.values, vec![Value::Integer(4), Value::Integer(8),]); - } - - #[test] - fn test_projection_more_output_columns_than_input() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, b, a * 2 as doubled_a, b * 3 as tripled_b, a + b as sum, hex(c) as hex_c FROM t"; - - let view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.project_operator.is_some()); - let project_op = view.project_operator.as_ref().unwrap(); - - let mut delta = Delta::new(); - delta.insert( - 1, - 
vec![Value::Integer(5), Value::Integer(2), Value::Integer(15)], - ); - - let mut temp_project = project_op.clone(); - temp_project.initialize(delta); - let result = temp_project.get_current_state(); - - let (output, _weight) = result.changes.first().unwrap(); - // 3 input columns -> 6 output columns - assert_eq!( - output.values, - vec![ - Value::Integer(5), // a - Value::Integer(2), // b - Value::Integer(10), // a * 2 - Value::Integer(6), // b * 3 - Value::Integer(7), // a + b - Value::Text("3135".into()), // hex(15) - SQLite converts to string "15" then hex encodes - ] - ); - } - - #[test] - fn test_aggregation_count_with_group_by() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, COUNT(*) FROM t GROUP BY a"; - - let mut view = IncrementalView::from_sql(sql, &schema).unwrap(); - - // Verify the view has an aggregate operator - assert!(view.aggregate_operator.is_some()); - - // Insert some test data - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)], - ); - delta.insert( - 2, - vec![Value::Integer(2), Value::Integer(20), Value::Integer(200)], - ); - delta.insert( - 3, - vec![Value::Integer(1), Value::Integer(30), Value::Integer(300)], - ); - - // Process the delta - view.merge_delta(&delta); - - // Verify we only processed the 3 rows we inserted - assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 3); - - // Check the aggregated results - let results = view.current_data(None); - - // Should have 2 groups: a=1 with count=2, a=2 with count=1 - assert_eq!(results.len(), 2); - - // Find the group with a=1 - let group1 = results - .iter() - .find(|(_, vals)| vals[0] == Value::Integer(1)) - .unwrap(); - assert_eq!(group1.1[0], Value::Integer(1)); // a=1 - assert_eq!(group1.1[1], Value::Integer(2)); // COUNT(*)=2 - - // Find the group with a=2 - let group2 = results - .iter() - .find(|(_, vals)| vals[0] == Value::Integer(2)) - .unwrap(); - assert_eq!(group2.1[0], Value::Integer(2)); // a=2 - assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1 - } - - #[test] - fn test_aggregation_sum_with_filter() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT SUM(b) FROM t WHERE a > 1"; - - let mut view = IncrementalView::from_sql(sql, &schema).unwrap(); - - assert!(view.aggregate_operator.is_some()); - assert!(view.filter_operator.is_some()); - - let mut delta = Delta::new(); - delta.insert( - 1, - vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)], - ); - delta.insert( - 2, - vec![Value::Integer(2), Value::Integer(20), Value::Integer(200)], - ); - delta.insert( - 3, - vec![Value::Integer(3), Value::Integer(30), Value::Integer(300)], - ); - - view.merge_delta(&delta); - - // Should filter all 3 rows - assert_eq!(view.tracker.lock().unwrap().filter_evaluations, 3); - // But only aggregate the 2 that passed the filter (a > 1) - assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2); - - let results = view.current_data(None); - - // Should have 1 row with sum of b where a > 1 - assert_eq!(results.len(), 1); - assert_eq!(results[0].1[0], Value::Integer(50)); // SUM(b) = 20 + 30 - } - - #[test] - fn test_aggregation_incremental_updates() { - let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT a, COUNT(*), SUM(b) FROM t GROUP BY a"; - - let mut view = IncrementalView::from_sql(sql, &schema).unwrap(); - - // Initial insert - let mut delta1 = Delta::new(); - delta1.insert( - 1, - 
vec![Value::Integer(1), Value::Integer(10), Value::Integer(100)], - ); - delta1.insert( - 2, - vec![Value::Integer(1), Value::Integer(20), Value::Integer(200)], - ); - - view.merge_delta(&delta1); - - // Verify we processed exactly 2 rows for the first batch - assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2); - - // Check initial state - let results1 = view.current_data(None); - assert_eq!(results1.len(), 1); - assert_eq!(results1[0].1[1], Value::Integer(2)); // COUNT(*)=2 - assert_eq!(results1[0].1[2], Value::Integer(30)); // SUM(b)=30 - - // Reset counter to track second batch separately - view.tracker.lock().unwrap().aggregation_updates = 0; - - // Add more data - let mut delta2 = Delta::new(); - delta2.insert( - 3, - vec![Value::Integer(1), Value::Integer(5), Value::Integer(300)], - ); - delta2.insert( - 4, - vec![Value::Integer(2), Value::Integer(15), Value::Integer(400)], - ); - - view.merge_delta(&delta2); - - // Should only process the 2 new rows, not recompute everything - assert_eq!(view.tracker.lock().unwrap().aggregation_updates, 2); - - // Check updated state - let results2 = view.current_data(None); - assert_eq!(results2.len(), 2); - - // Group a=1 - let group1 = results2 - .iter() - .find(|(_, vals)| vals[0] == Value::Integer(1)) - .unwrap(); - assert_eq!(group1.1[1], Value::Integer(3)); // COUNT(*)=3 - assert_eq!(group1.1[2], Value::Integer(35)); // SUM(b)=35 - - // Group a=2 - let group2 = results2 - .iter() - .find(|(_, vals)| vals[0] == Value::Integer(2)) - .unwrap(); - assert_eq!(group2.1[1], Value::Integer(1)); // COUNT(*)=1 - assert_eq!(group2.1[2], Value::Integer(15)); // SUM(b)=15 - } -} From 143c84c4e0606eed147280ce72331213818a055b Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Tue, 26 Aug 2025 16:58:06 -0500 Subject: [PATCH 9/9] add tests for rollback of views. 
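
The tests all follow the same shape: read the view, mutate the base table
inside an explicit transaction, verify that reads inside the transaction see
the uncommitted changes through the view, then ROLLBACK and verify the view
returns to its committed contents. A sketch of the pattern (not one of the
tests verbatim):

  CREATE MATERIALIZED VIEW v AS SELECT * FROM t WHERE b > 15;
  BEGIN;
  INSERT INTO t VALUES (4, 40);  -- SELECT * FROM v now includes 4|40
  ROLLBACK;                      -- SELECT * FROM v no longer includes 4|40

Covered cases: plain filters, GROUP BY aggregations (SUM/COUNT), filtered
aggregations, mixed INSERT/UPDATE/DELETE within one transaction, and a view
that is empty before the transaction starts.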
--- testing/materialized_views.test | 171 ++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) diff --git a/testing/materialized_views.test b/testing/materialized_views.test index 2b6a56be3..5d226b016 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -385,3 +385,174 @@ do_execsql_test_on_specific_db {:memory:} matview-projections { SELECT * from v; } {4|3|7|22|3 3|4|7|22|3} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-insert { + CREATE TABLE t(a INTEGER, b INTEGER); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); + + CREATE MATERIALIZED VIEW v AS + SELECT * FROM t WHERE b > 15; + + SELECT * FROM v ORDER BY a; + + BEGIN; + INSERT INTO t VALUES (4, 40), (5, 50); + SELECT * FROM v ORDER BY a; + ROLLBACK; + + SELECT * FROM v ORDER BY a; +} {2|20 +3|30 +2|20 +3|30 +4|40 +5|50 +2|20 +3|30} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-delete { + CREATE TABLE t(a INTEGER, b INTEGER); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30), (4, 40); + + CREATE MATERIALIZED VIEW v AS + SELECT * FROM t WHERE b > 15; + + SELECT * FROM v ORDER BY a; + + BEGIN; + DELETE FROM t WHERE a IN (2, 3); + SELECT * FROM v ORDER BY a; + ROLLBACK; + + SELECT * FROM v ORDER BY a; +} {2|20 +3|30 +4|40 +4|40 +2|20 +3|30 +4|40} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-update { + CREATE TABLE t(a INTEGER, b INTEGER); + INSERT INTO t VALUES (1, 10), (2, 20), (3, 30); + + CREATE MATERIALIZED VIEW v AS + SELECT * FROM t WHERE b > 15; + + SELECT * FROM v ORDER BY a; + + BEGIN; + UPDATE t SET b = 5 WHERE a = 2; + UPDATE t SET b = 35 WHERE a = 1; + SELECT * FROM v ORDER BY a; + ROLLBACK; + + SELECT * FROM v ORDER BY a; +} {2|20 +3|30 +1|35 +3|30 +2|20 +3|30} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-aggregation { + CREATE TABLE sales(product_id INTEGER, amount INTEGER); + INSERT INTO sales VALUES (1, 100), (1, 200), (2, 150), (2, 250); + + CREATE MATERIALIZED VIEW product_totals AS + SELECT product_id, SUM(amount) as total, COUNT(*) as cnt + FROM sales + GROUP BY product_id; + + SELECT * FROM product_totals ORDER BY product_id; + + BEGIN; + INSERT INTO sales VALUES (1, 50), (3, 300); + SELECT * FROM product_totals ORDER BY product_id; + ROLLBACK; + + SELECT * FROM product_totals ORDER BY product_id; +} {1|300|2 +2|400|2 +1|350|3 +2|400|2 +3|300|1 +1|300|2 +2|400|2} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-mixed-operations { + CREATE TABLE orders(id INTEGER PRIMARY KEY, customer INTEGER, amount INTEGER); + INSERT INTO orders VALUES (1, 100, 50), (2, 200, 75), (3, 100, 25); + + CREATE MATERIALIZED VIEW customer_totals AS + SELECT customer, SUM(amount) as total, COUNT(*) as cnt + FROM orders + GROUP BY customer; + + SELECT * FROM customer_totals ORDER BY customer; + + BEGIN; + INSERT INTO orders VALUES (4, 100, 100); + UPDATE orders SET amount = 150 WHERE id = 2; + DELETE FROM orders WHERE id = 3; + SELECT * FROM customer_totals ORDER BY customer; + ROLLBACK; + + SELECT * FROM customer_totals ORDER BY customer; +} {100|75|2 +200|75|1 +100|150|2 +200|150|1 +100|75|2 +200|75|1} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-filtered-aggregation { + CREATE TABLE transactions(id INTEGER, account INTEGER, amount INTEGER, type TEXT); + INSERT INTO transactions VALUES + (1, 100, 50, 'deposit'), + (2, 100, 30, 'withdraw'), + (3, 200, 100, 'deposit'), + (4, 200, 40, 'withdraw'); + + CREATE MATERIALIZED VIEW deposits AS + SELECT account, SUM(amount) as total_deposits, COUNT(*) as cnt 
+ FROM transactions + WHERE type = 'deposit' + GROUP BY account; + + SELECT * FROM deposits ORDER BY account; + + BEGIN; + INSERT INTO transactions VALUES (5, 100, 75, 'deposit'); + UPDATE transactions SET amount = 60 WHERE id = 1; + DELETE FROM transactions WHERE id = 3; + SELECT * FROM deposits ORDER BY account; + ROLLBACK; + + SELECT * FROM deposits ORDER BY account; +} {100|50|1 +200|100|1 +100|135|2 +100|50|1 +200|100|1} + +do_execsql_test_on_specific_db {:memory:} matview-rollback-empty-view { + CREATE TABLE t(a INTEGER, b INTEGER); + INSERT INTO t VALUES (1, 5), (2, 8); + + CREATE MATERIALIZED VIEW v AS + SELECT * FROM t WHERE b > 10; + + SELECT COUNT(*) FROM v; + + BEGIN; + INSERT INTO t VALUES (3, 15), (4, 20); + SELECT * FROM v ORDER BY a; + ROLLBACK; + + SELECT COUNT(*) FROM v; +} {0 +3|15 +4|20 +0}