No need for ResultSetColumn to be an enum

This commit is contained in:
jussisaurio
2024-11-25 18:25:09 +02:00
parent bb8ba7fb01
commit ac12e9c7fd
3 changed files with 97 additions and 166 deletions

View File

@@ -921,38 +921,34 @@ fn inner_loop_source_emit(
// should be emitted in the SELECT clause order, not the ORDER BY clause order.
let mut result_columns_to_skip: Option<Vec<usize>> = None;
for (i, rc) in result_columns.iter().enumerate() {
match rc {
ResultSetColumn::Expr {
expr,
contains_aggregates,
} => {
assert!(!*contains_aggregates);
let found = order_by.iter().enumerate().find(|(_, (e, _))| e == expr);
if let Some((j, _)) = found {
if let Some(ref mut v) = result_columns_to_skip {
v.push(i);
} else {
result_columns_to_skip = Some(vec![i]);
}
m.result_column_indexes_in_orderby_sorter.insert(i, j);
if !rc.contains_aggregates {
let found = order_by
.iter()
.enumerate()
.find(|(_, (e, _))| e == &rc.expr);
if let Some((j, _)) = found {
if let Some(ref mut v) = result_columns_to_skip {
v.push(i);
} else {
result_columns_to_skip = Some(vec![i]);
}
m.result_column_indexes_in_orderby_sorter.insert(i, j);
}
ResultSetColumn::Agg(agg) => {
// TODO: implement a custom equality check for expressions
// there are lots of examples where this breaks, even simple ones like
// sum(x) != SUM(x)
let found = order_by
.iter()
.enumerate()
.find(|(_, (expr, _))| expr == &agg.original_expr);
if let Some((j, _)) = found {
if let Some(ref mut v) = result_columns_to_skip {
v.push(i);
} else {
result_columns_to_skip = Some(vec![i]);
}
m.result_column_indexes_in_orderby_sorter.insert(i, j);
} else {
// TODO: implement a custom equality check for expressions
// there are lots of examples where this breaks, even simple ones like
// sum(x) != SUM(x)
let found = order_by
.iter()
.enumerate()
.find(|(_, (expr, _))| expr == &rc.expr);
if let Some((j, _)) = found {
if let Some(ref mut v) = result_columns_to_skip {
v.push(i);
} else {
result_columns_to_skip = Some(vec![i]);
}
m.result_column_indexes_in_orderby_sorter.insert(i, j);
}
}
}
@@ -976,16 +972,8 @@ fn inner_loop_source_emit(
continue;
}
}
match rc {
ResultSetColumn::Expr {
expr,
contains_aggregates,
} => {
assert!(!*contains_aggregates);
translate_expr(program, Some(referenced_tables), expr, cur_reg, None)?;
}
other => unreachable!("{:?}", other),
}
assert!(!rc.contains_aggregates);
translate_expr(program, Some(referenced_tables), &rc.expr, cur_reg, None)?;
m.result_column_indexes_in_orderby_sorter
.insert(i, cur_idx_in_orderby_sorter);
cur_idx_in_orderby_sorter += 1;
@@ -1014,43 +1002,23 @@ fn inner_loop_source_emit(
let reg = start_reg + i;
translate_aggregation(program, referenced_tables, agg, reg)?;
}
for (i, expr) in result_columns.iter().enumerate() {
match expr {
ResultSetColumn::Expr {
expr,
contains_aggregates,
} => {
if *contains_aggregates {
// Do nothing, aggregates will be computed above and this full result expression will be
// computed later
continue;
}
let reg = start_reg + num_aggs + i;
translate_expr(program, Some(referenced_tables), expr, reg, None)?;
}
ResultSetColumn::Agg(_) => { /* do nothing, aggregates are computed above */ }
for (i, rc) in result_columns.iter().enumerate() {
if rc.contains_aggregates {
// Do nothing, aggregates are computed above
continue;
}
let reg = start_reg + num_aggs + i;
translate_expr(program, Some(referenced_tables), &rc.expr, reg, None)?;
}
Ok(())
}
InnerLoopEmitTarget::ResultRow { limit } => {
assert!(aggregates.is_none());
let start_reg = program.alloc_registers(result_columns.len());
for (i, expr) in result_columns.iter().enumerate() {
match expr {
ResultSetColumn::Expr {
expr,
contains_aggregates,
} => {
assert!(!*contains_aggregates);
let reg = start_reg + i;
translate_expr(program, Some(referenced_tables), expr, reg, None)?;
}
other => unreachable!(
"Unexpected non-scalar result column in inner loop: {:?}",
other
),
}
for (i, rc) in result_columns.iter().enumerate() {
assert!(!rc.contains_aggregates);
let reg = start_reg + i;
translate_expr(program, Some(referenced_tables), &rc.expr, reg, None)?;
}
emit_result_row(
program,
@@ -1483,34 +1451,34 @@ fn group_by_emit(
let mut result_columns_to_skip: Option<Vec<usize>> = None;
if let Some(order_by) = order_by {
for (i, rc) in result_columns.iter().enumerate() {
match rc {
ResultSetColumn::Expr { expr, .. } => {
let found = order_by.iter().enumerate().find(|(_, (e, _))| e == expr);
if let Some((j, _)) = found {
if let Some(ref mut v) = result_columns_to_skip {
v.push(i);
} else {
result_columns_to_skip = Some(vec![i]);
}
m.result_column_indexes_in_orderby_sorter.insert(i, j);
if !rc.contains_aggregates {
let found = order_by
.iter()
.enumerate()
.find(|(_, (e, _))| e == &rc.expr);
if let Some((j, _)) = found {
if let Some(ref mut v) = result_columns_to_skip {
v.push(i);
} else {
result_columns_to_skip = Some(vec![i]);
}
m.result_column_indexes_in_orderby_sorter.insert(i, j);
}
ResultSetColumn::Agg(agg) => {
// TODO: implement a custom equality check for expressions
// there are lots of examples where this breaks, even simple ones like
// sum(x) != SUM(x)
let found = order_by
.iter()
.enumerate()
.find(|(_, (expr, _))| expr == &agg.original_expr);
if let Some((j, _)) = found {
if let Some(ref mut v) = result_columns_to_skip {
v.push(i);
} else {
result_columns_to_skip = Some(vec![i]);
}
m.result_column_indexes_in_orderby_sorter.insert(i, j);
} else {
// TODO: implement a custom equality check for expressions
// there are lots of examples where this breaks, even simple ones like
// sum(x) != SUM(x)
let found = order_by
.iter()
.enumerate()
.find(|(_, (expr, _))| expr == &rc.expr);
if let Some((j, _)) = found {
if let Some(ref mut v) = result_columns_to_skip {
v.push(i);
} else {
result_columns_to_skip = Some(vec![i]);
}
m.result_column_indexes_in_orderby_sorter.insert(i, j);
}
}
}
@@ -1542,29 +1510,13 @@ fn group_by_emit(
continue;
}
}
match rc {
ResultSetColumn::Expr { expr, .. } => {
translate_expr(
program,
Some(referenced_tables),
expr,
cur_reg,
Some(&precomputed_exprs_to_register),
)?;
}
ResultSetColumn::Agg(agg) => {
let found = aggregates.iter().enumerate().find(|(_, a)| **a == *agg);
if let Some((i, _)) = found {
program.emit_insn(Insn::Copy {
src_reg: agg_start_reg + i,
dst_reg: cur_reg,
amount: 0,
});
} else {
unreachable!("agg {:?} not found", agg);
}
}
}
translate_expr(
program,
Some(referenced_tables),
&rc.expr,
cur_reg,
Some(&precomputed_exprs_to_register),
)?;
m.result_column_indexes_in_orderby_sorter
.insert(i, res_col_idx_in_orderby_sorter);
res_col_idx_in_orderby_sorter += 1;
@@ -1647,29 +1599,13 @@ fn agg_without_group_by_emit(
let output_reg = program.alloc_registers(result_columns.len());
for (i, rc) in result_columns.iter().enumerate() {
match rc {
ResultSetColumn::Expr { expr, .. } => {
translate_expr(
program,
Some(referenced_tables),
expr,
output_reg + i,
Some(&precomputed_exprs_to_register),
)?;
}
ResultSetColumn::Agg(agg) => {
let found = aggregates.iter().enumerate().find(|(_, a)| **a == *agg);
if let Some((i, _)) = found {
program.emit_insn(Insn::Copy {
src_reg: agg_start_reg + i,
dst_reg: output_reg + i,
amount: 0,
});
} else {
unreachable!("agg {:?} not found", agg);
}
}
}
translate_expr(
program,
Some(referenced_tables),
&rc.expr,
output_reg + i,
Some(&precomputed_exprs_to_register),
)?;
}
// This always emits a ResultRow because currently it can only be used for a single row result
emit_result_row(program, output_reg, result_columns.len(), None);
@@ -1698,17 +1634,14 @@ fn order_by_emit(
ty: crate::schema::Type::Null,
});
}
for (i, expr) in result_columns.iter().enumerate() {
for (i, rc) in result_columns.iter().enumerate() {
if let Some(ref v) = m.result_columns_to_skip_in_orderby_sorter {
if v.contains(&i) {
continue;
}
}
pseudo_columns.push(Column {
name: match expr {
ResultSetColumn::Expr { expr, .. } => expr.to_string(),
ResultSetColumn::Agg(agg) => agg.to_string(),
},
name: rc.expr.to_string(),
primary_key: false,
ty: crate::schema::Type::Null,
});

View File

@@ -13,12 +13,10 @@ use crate::{
};
#[derive(Debug)]
pub enum ResultSetColumn {
Expr {
expr: ast::Expr,
contains_aggregates: bool,
},
Agg(Aggregate),
pub struct ResultSetColumn {
pub expr: ast::Expr,
// TODO: encode which aggregates (e.g. index bitmask of plan.aggregates) are present in this column
pub contains_aggregates: bool,
}
#[derive(Debug)]

View File

@@ -60,6 +60,7 @@ fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec<Aggregate>) {
resolve_aggregates(lhs, aggs);
resolve_aggregates(rhs, aggs);
}
// TODO: handle other expressions that may contain aggregates
_ => {}
}
}
@@ -272,7 +273,7 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
ast::ResultColumn::Star => {
for table_reference in plan.referenced_tables.iter() {
for (idx, col) in table_reference.table.columns.iter().enumerate() {
plan.result_columns.push(ResultSetColumn::Expr {
plan.result_columns.push(ResultSetColumn {
expr: ast::Expr::Column {
database: None, // TODO: support different databases
table: table_reference.table_index,
@@ -296,7 +297,7 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
}
let table_reference = referenced_table.unwrap();
for (idx, col) in table_reference.table.columns.iter().enumerate() {
plan.result_columns.push(ResultSetColumn::Expr {
plan.result_columns.push(ResultSetColumn {
expr: ast::Expr::Column {
database: None, // TODO: support different databases
table: table_reference.table_index,
@@ -333,14 +334,17 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
original_expr: expr.clone(),
};
aggregate_expressions.push(agg.clone());
plan.result_columns.push(ResultSetColumn::Agg(agg));
plan.result_columns.push(ResultSetColumn {
expr: expr.clone(),
contains_aggregates: true,
});
}
Ok(_) => {
let cur_agg_count = aggregate_expressions.len();
resolve_aggregates(&expr, &mut aggregate_expressions);
let contains_aggregates =
cur_agg_count != aggregate_expressions.len();
plan.result_columns.push(ResultSetColumn::Expr {
plan.result_columns.push(ResultSetColumn {
expr: expr.clone(),
contains_aggregates,
});
@@ -364,7 +368,10 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
original_expr: expr.clone(),
};
aggregate_expressions.push(agg.clone());
plan.result_columns.push(ResultSetColumn::Agg(agg));
plan.result_columns.push(ResultSetColumn {
expr: expr.clone(),
contains_aggregates: true,
});
} else {
crate::bail_parse_error!(
"Invalid aggregate function: {}",
@@ -372,23 +379,16 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result<P
);
}
}
ast::Expr::Binary(lhs, _, rhs) => {
expr => {
let cur_agg_count = aggregate_expressions.len();
resolve_aggregates(&lhs, &mut aggregate_expressions);
resolve_aggregates(&rhs, &mut aggregate_expressions);
resolve_aggregates(expr, &mut aggregate_expressions);
let contains_aggregates =
cur_agg_count != aggregate_expressions.len();
plan.result_columns.push(ResultSetColumn::Expr {
plan.result_columns.push(ResultSetColumn {
expr: expr.clone(),
contains_aggregates,
});
}
e => {
plan.result_columns.push(ResultSetColumn::Expr {
expr: e.clone(),
contains_aggregates: false,
});
}
}
}
}