diff --git a/core/translate.rs b/core/translate.rs index c48dc513b..e21d7f469 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -454,7 +454,23 @@ fn translate_aggregation( }); target_register } - AggFunc::Count => todo!(), + AggFunc::Count => { + let expr_reg = if args.is_empty() { + program.alloc_register() + } else { + let expr = &args[0]; + let expr_reg = program.alloc_register(); + let _ = translate_expr(program, cursor_id, table, expr, expr_reg); + expr_reg + }; + program.emit_insn(Insn::AggStep { + acc_reg: target_register, + col: expr_reg, + func: AggFunc::Count, + }); + target_register + } + AggFunc::GroupConcat => todo!(), AggFunc::Max => todo!(), AggFunc::Min => todo!(), diff --git a/core/types.rs b/core/types.rs index 4bfce889c..72fa41015 100644 --- a/core/types.rs +++ b/core/types.rs @@ -46,6 +46,7 @@ impl Display for OwnedValue { OwnedValue::Agg(a) => match a.as_ref() { AggContext::Avg(acc, _count) => write!(f, "{}", acc), AggContext::Sum(acc) => write!(f, "{}", acc), + AggContext::Count(count) => write!(f, "{}", count), }, OwnedValue::Record(r) => write!(f, "{:?}", r), } @@ -56,6 +57,7 @@ impl Display for OwnedValue { pub enum AggContext { Avg(OwnedValue, OwnedValue), // acc and count Sum(OwnedValue), + Count(OwnedValue), } impl std::ops::Add for OwnedValue { @@ -160,6 +162,7 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> { OwnedValue::Agg(a) => match a.as_ref() { AggContext::Avg(acc, _count) => to_value(acc), // we assume aggfinal was called AggContext::Sum(acc) => to_value(acc), + AggContext::Count(count) => to_value(count), }, OwnedValue::Record(_) => todo!(), } diff --git a/core/vdbe.rs b/core/vdbe.rs index e2200660d..ddc8865a3 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -430,6 +430,9 @@ impl Program { AggFunc::Sum => { OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.0)))) } + AggFunc::Count => { + OwnedValue::Agg(Box::new(AggContext::Count(OwnedValue::Integer(0)))) + } _ => { todo!(); } @@ -459,6 +462,16 @@ impl Program { }; *acc += col; } + AggFunc::Count => { + let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() + else { + unreachable!(); + }; + let AggContext::Count(count) = agg.borrow_mut() else { + unreachable!(); + }; + *count += 1; + } _ => { todo!(); } @@ -478,6 +491,7 @@ impl Program { *acc /= count.clone(); } AggFunc::Sum => {} + AggFunc::Count => {} _ => { todo!(); } diff --git a/testing/all.test b/testing/all.test index 2b07aecf6..b1fd6cb95 100755 --- a/testing/all.test +++ b/testing/all.test @@ -43,6 +43,10 @@ do_execsql_test select-limit { SELECT id FROM users LIMIT 1; } {1} +do_execsql_test select-count { + SELECT count(id) FROM users; +} {10000} + do_execsql_test select-limit-0 { SELECT id FROM users LIMIT 0; } {}