From 8a9eb74f9b8557acf4bb16045693385165f0789f Mon Sep 17 00:00:00 2001 From: Bennett Clement Date: Thu, 11 Jul 2024 23:40:55 +0800 Subject: [PATCH] Implement total() aggregation function - Returns 0.0 when called on non integer / non float columns - Always returns floating point number - fix: default for sum() should be NULL when there is no non-NULL row per docs --- core/translate.rs | 15 ++++++++++++++- core/types.rs | 4 +++- core/vdbe.rs | 21 +++++++++++++++++---- testing/all.test | 4 ++++ 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/core/translate.rs b/core/translate.rs index b30d75811..09c96feb3 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -653,7 +653,20 @@ fn translate_aggregation( }); target_register } - AggFunc::Total => todo!(), + AggFunc::Total => { + if args.len() != 1 { + anyhow::bail!("Parse error: total bad number of arguments"); + } + let expr = &args[0]; + let expr_reg = program.alloc_register(); + let _ = translate_expr(program, select, expr, expr_reg)?; + program.emit_insn(Insn::AggStep { + acc_reg: target_register, + col: expr_reg, + func: AggFunc::Total, + }); + target_register + } }; Ok(dest) } diff --git a/core/types.rs b/core/types.rs index 87f80fd2f..5f99f1482 100644 --- a/core/types.rs +++ b/core/types.rs @@ -81,7 +81,9 @@ impl std::ops::Add for OwnedValue { (OwnedValue::Float(float_left), OwnedValue::Float(float_right)) => { OwnedValue::Float(float_left + float_right) } - _ => unreachable!(), + (lhs, OwnedValue::Null) => lhs, + (OwnedValue::Null, rhs) => rhs, + _ => OwnedValue::Float(0.0), } } } diff --git a/core/vdbe.rs b/core/vdbe.rs index 907f85b25..27fb65d95 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -445,7 +445,13 @@ impl Program { OwnedValue::Integer(0), ))), AggFunc::Sum => { - OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Integer(0)))) + OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Null))) + } + AggFunc::Total => { + // The result of total() is always a floating point value. + // No overflow error is ever raised if any prior input was a floating point value. + // Total() never throws an integer overflow. + OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.0)))) } AggFunc::Count => { OwnedValue::Agg(Box::new(AggContext::Count(OwnedValue::Integer(0)))) @@ -496,7 +502,7 @@ impl Program { *acc += col; *count += 1; } - AggFunc::Sum => { + AggFunc::Sum | AggFunc::Total => { let col = state.registers[*col].clone(); let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() else { @@ -599,7 +605,7 @@ impl Program { }; *acc /= count.clone(); } - AggFunc::Sum => {} + AggFunc::Sum | AggFunc::Total => {} AggFunc::Count => {} AggFunc::Max => {} AggFunc::Min => {} @@ -973,7 +979,14 @@ fn insn_to_str(addr: BranchOffset, insn: &Insn, indent: String) -> String { }; format!( "{:<4} {:<17} {:<4} {:<4} {:<4} {:<13} {:<2} {}", - addr, &(indent + opcode), p1, p2, p3, p4.to_string(), p5, comment + addr, + &(indent + opcode), + p1, + p2, + p3, + p4.to_string(), + p5, + comment ) } diff --git a/testing/all.test b/testing/all.test index 61a5f8d89..91fe87ec0 100755 --- a/testing/all.test +++ b/testing/all.test @@ -41,6 +41,10 @@ do_execsql_test select-sum { SELECT sum(age) FROM users; } {503960} +do_execsql_test select-total { + SELECT sum(age) FROM users; +} {503960} + do_execsql_test select-limit { SELECT id FROM users LIMIT 1; } {1}