From ac12e9c7fd4844677c973bdb58fb53d14c58da23 Mon Sep 17 00:00:00 2001 From: jussisaurio Date: Mon, 25 Nov 2024 18:25:09 +0200 Subject: [PATCH] No need for ResultSetColumn to be an enum --- core/translate/emitter.rs | 223 +++++++++++++------------------------- core/translate/plan.rs | 10 +- core/translate/planner.rs | 30 ++--- 3 files changed, 97 insertions(+), 166 deletions(-) diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 78a3c5a0f..9f292f625 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -921,38 +921,34 @@ fn inner_loop_source_emit( // should be emitted in the SELECT clause order, not the ORDER BY clause order. let mut result_columns_to_skip: Option> = None; for (i, rc) in result_columns.iter().enumerate() { - match rc { - ResultSetColumn::Expr { - expr, - contains_aggregates, - } => { - assert!(!*contains_aggregates); - let found = order_by.iter().enumerate().find(|(_, (e, _))| e == expr); - if let Some((j, _)) = found { - if let Some(ref mut v) = result_columns_to_skip { - v.push(i); - } else { - result_columns_to_skip = Some(vec![i]); - } - m.result_column_indexes_in_orderby_sorter.insert(i, j); + if !rc.contains_aggregates { + let found = order_by + .iter() + .enumerate() + .find(|(_, (e, _))| e == &rc.expr); + if let Some((j, _)) = found { + if let Some(ref mut v) = result_columns_to_skip { + v.push(i); + } else { + result_columns_to_skip = Some(vec![i]); } + m.result_column_indexes_in_orderby_sorter.insert(i, j); } - ResultSetColumn::Agg(agg) => { - // TODO: implement a custom equality check for expressions - // there are lots of examples where this breaks, even simple ones like - // sum(x) != SUM(x) - let found = order_by - .iter() - .enumerate() - .find(|(_, (expr, _))| expr == &agg.original_expr); - if let Some((j, _)) = found { - if let Some(ref mut v) = result_columns_to_skip { - v.push(i); - } else { - result_columns_to_skip = Some(vec![i]); - } - m.result_column_indexes_in_orderby_sorter.insert(i, j); + } else { + // TODO: implement a custom equality check for expressions + // there are lots of examples where this breaks, even simple ones like + // sum(x) != SUM(x) + let found = order_by + .iter() + .enumerate() + .find(|(_, (expr, _))| expr == &rc.expr); + if let Some((j, _)) = found { + if let Some(ref mut v) = result_columns_to_skip { + v.push(i); + } else { + result_columns_to_skip = Some(vec![i]); } + m.result_column_indexes_in_orderby_sorter.insert(i, j); } } } @@ -976,16 +972,8 @@ fn inner_loop_source_emit( continue; } } - match rc { - ResultSetColumn::Expr { - expr, - contains_aggregates, - } => { - assert!(!*contains_aggregates); - translate_expr(program, Some(referenced_tables), expr, cur_reg, None)?; - } - other => unreachable!("{:?}", other), - } + assert!(!rc.contains_aggregates); + translate_expr(program, Some(referenced_tables), &rc.expr, cur_reg, None)?; m.result_column_indexes_in_orderby_sorter .insert(i, cur_idx_in_orderby_sorter); cur_idx_in_orderby_sorter += 1; @@ -1014,43 +1002,23 @@ fn inner_loop_source_emit( let reg = start_reg + i; translate_aggregation(program, referenced_tables, agg, reg)?; } - for (i, expr) in result_columns.iter().enumerate() { - match expr { - ResultSetColumn::Expr { - expr, - contains_aggregates, - } => { - if *contains_aggregates { - // Do nothing, aggregates will be computed above and this full result expression will be - // computed later - continue; - } - let reg = start_reg + num_aggs + i; - translate_expr(program, Some(referenced_tables), expr, reg, None)?; - } - ResultSetColumn::Agg(_) => { /* do nothing, aggregates are computed above */ } + for (i, rc) in result_columns.iter().enumerate() { + if rc.contains_aggregates { + // Do nothing, aggregates are computed above + continue; } + let reg = start_reg + num_aggs + i; + translate_expr(program, Some(referenced_tables), &rc.expr, reg, None)?; } Ok(()) } InnerLoopEmitTarget::ResultRow { limit } => { assert!(aggregates.is_none()); let start_reg = program.alloc_registers(result_columns.len()); - for (i, expr) in result_columns.iter().enumerate() { - match expr { - ResultSetColumn::Expr { - expr, - contains_aggregates, - } => { - assert!(!*contains_aggregates); - let reg = start_reg + i; - translate_expr(program, Some(referenced_tables), expr, reg, None)?; - } - other => unreachable!( - "Unexpected non-scalar result column in inner loop: {:?}", - other - ), - } + for (i, rc) in result_columns.iter().enumerate() { + assert!(!rc.contains_aggregates); + let reg = start_reg + i; + translate_expr(program, Some(referenced_tables), &rc.expr, reg, None)?; } emit_result_row( program, @@ -1483,34 +1451,34 @@ fn group_by_emit( let mut result_columns_to_skip: Option> = None; if let Some(order_by) = order_by { for (i, rc) in result_columns.iter().enumerate() { - match rc { - ResultSetColumn::Expr { expr, .. } => { - let found = order_by.iter().enumerate().find(|(_, (e, _))| e == expr); - if let Some((j, _)) = found { - if let Some(ref mut v) = result_columns_to_skip { - v.push(i); - } else { - result_columns_to_skip = Some(vec![i]); - } - m.result_column_indexes_in_orderby_sorter.insert(i, j); + if !rc.contains_aggregates { + let found = order_by + .iter() + .enumerate() + .find(|(_, (e, _))| e == &rc.expr); + if let Some((j, _)) = found { + if let Some(ref mut v) = result_columns_to_skip { + v.push(i); + } else { + result_columns_to_skip = Some(vec![i]); } + m.result_column_indexes_in_orderby_sorter.insert(i, j); } - ResultSetColumn::Agg(agg) => { - // TODO: implement a custom equality check for expressions - // there are lots of examples where this breaks, even simple ones like - // sum(x) != SUM(x) - let found = order_by - .iter() - .enumerate() - .find(|(_, (expr, _))| expr == &agg.original_expr); - if let Some((j, _)) = found { - if let Some(ref mut v) = result_columns_to_skip { - v.push(i); - } else { - result_columns_to_skip = Some(vec![i]); - } - m.result_column_indexes_in_orderby_sorter.insert(i, j); + } else { + // TODO: implement a custom equality check for expressions + // there are lots of examples where this breaks, even simple ones like + // sum(x) != SUM(x) + let found = order_by + .iter() + .enumerate() + .find(|(_, (expr, _))| expr == &rc.expr); + if let Some((j, _)) = found { + if let Some(ref mut v) = result_columns_to_skip { + v.push(i); + } else { + result_columns_to_skip = Some(vec![i]); } + m.result_column_indexes_in_orderby_sorter.insert(i, j); } } } @@ -1542,29 +1510,13 @@ fn group_by_emit( continue; } } - match rc { - ResultSetColumn::Expr { expr, .. } => { - translate_expr( - program, - Some(referenced_tables), - expr, - cur_reg, - Some(&precomputed_exprs_to_register), - )?; - } - ResultSetColumn::Agg(agg) => { - let found = aggregates.iter().enumerate().find(|(_, a)| **a == *agg); - if let Some((i, _)) = found { - program.emit_insn(Insn::Copy { - src_reg: agg_start_reg + i, - dst_reg: cur_reg, - amount: 0, - }); - } else { - unreachable!("agg {:?} not found", agg); - } - } - } + translate_expr( + program, + Some(referenced_tables), + &rc.expr, + cur_reg, + Some(&precomputed_exprs_to_register), + )?; m.result_column_indexes_in_orderby_sorter .insert(i, res_col_idx_in_orderby_sorter); res_col_idx_in_orderby_sorter += 1; @@ -1647,29 +1599,13 @@ fn agg_without_group_by_emit( let output_reg = program.alloc_registers(result_columns.len()); for (i, rc) in result_columns.iter().enumerate() { - match rc { - ResultSetColumn::Expr { expr, .. } => { - translate_expr( - program, - Some(referenced_tables), - expr, - output_reg + i, - Some(&precomputed_exprs_to_register), - )?; - } - ResultSetColumn::Agg(agg) => { - let found = aggregates.iter().enumerate().find(|(_, a)| **a == *agg); - if let Some((i, _)) = found { - program.emit_insn(Insn::Copy { - src_reg: agg_start_reg + i, - dst_reg: output_reg + i, - amount: 0, - }); - } else { - unreachable!("agg {:?} not found", agg); - } - } - } + translate_expr( + program, + Some(referenced_tables), + &rc.expr, + output_reg + i, + Some(&precomputed_exprs_to_register), + )?; } // This always emits a ResultRow because currently it can only be used for a single row result emit_result_row(program, output_reg, result_columns.len(), None); @@ -1698,17 +1634,14 @@ fn order_by_emit( ty: crate::schema::Type::Null, }); } - for (i, expr) in result_columns.iter().enumerate() { + for (i, rc) in result_columns.iter().enumerate() { if let Some(ref v) = m.result_columns_to_skip_in_orderby_sorter { if v.contains(&i) { continue; } } pseudo_columns.push(Column { - name: match expr { - ResultSetColumn::Expr { expr, .. } => expr.to_string(), - ResultSetColumn::Agg(agg) => agg.to_string(), - }, + name: rc.expr.to_string(), primary_key: false, ty: crate::schema::Type::Null, }); diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 6ab599372..ac75981f6 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -13,12 +13,10 @@ use crate::{ }; #[derive(Debug)] -pub enum ResultSetColumn { - Expr { - expr: ast::Expr, - contains_aggregates: bool, - }, - Agg(Aggregate), +pub struct ResultSetColumn { + pub expr: ast::Expr, + // TODO: encode which aggregates (e.g. index bitmask of plan.aggregates) are present in this column + pub contains_aggregates: bool, } #[derive(Debug)] diff --git a/core/translate/planner.rs b/core/translate/planner.rs index c2c5aa115..51706f108 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -60,6 +60,7 @@ fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec) { resolve_aggregates(lhs, aggs); resolve_aggregates(rhs, aggs); } + // TODO: handle other expressions that may contain aggregates _ => {} } } @@ -272,7 +273,7 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

{ for table_reference in plan.referenced_tables.iter() { for (idx, col) in table_reference.table.columns.iter().enumerate() { - plan.result_columns.push(ResultSetColumn::Expr { + plan.result_columns.push(ResultSetColumn { expr: ast::Expr::Column { database: None, // TODO: support different databases table: table_reference.table_index, @@ -296,7 +297,7 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

(schema: &Schema, select: ast::Select) -> Result

{ let cur_agg_count = aggregate_expressions.len(); resolve_aggregates(&expr, &mut aggregate_expressions); let contains_aggregates = cur_agg_count != aggregate_expressions.len(); - plan.result_columns.push(ResultSetColumn::Expr { + plan.result_columns.push(ResultSetColumn { expr: expr.clone(), contains_aggregates, }); @@ -364,7 +368,10 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

(schema: &Schema, select: ast::Select) -> Result

{ + expr => { let cur_agg_count = aggregate_expressions.len(); - resolve_aggregates(&lhs, &mut aggregate_expressions); - resolve_aggregates(&rhs, &mut aggregate_expressions); + resolve_aggregates(expr, &mut aggregate_expressions); let contains_aggregates = cur_agg_count != aggregate_expressions.len(); - plan.result_columns.push(ResultSetColumn::Expr { + plan.result_columns.push(ResultSetColumn { expr: expr.clone(), contains_aggregates, }); } - e => { - plan.result_columns.push(ResultSetColumn::Expr { - expr: e.clone(), - contains_aggregates: false, - }); - } } } }