diff --git a/core/translate.rs b/core/translate.rs index 2ace4a36c..2dbe09364 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -127,7 +127,7 @@ fn translate_select(schema: &Schema, select: Select) -> Result { AggregationFunc::Max => todo!(), AggregationFunc::Min => todo!(), AggregationFunc::StringAgg => todo!(), - AggregationFunc::Sum => todo!(), + AggregationFunc::Sum => AggFunc::Sum, AggregationFunc::Total => todo!(), }; program.emit_insn(Insn::AggFinal { @@ -453,7 +453,20 @@ fn translate_aggregation( AggregationFunc::Max => todo!(), AggregationFunc::Min => todo!(), AggregationFunc::StringAgg => todo!(), - AggregationFunc::Sum => todo!(), + AggregationFunc::Sum => { + if args.len() != 1 { + anyhow::bail!("Parse error: sum bad number of arguments"); + } + let expr = &args[0]; + let expr_reg = program.alloc_register(); + let _ = translate_expr(program, cursor_id, table, &expr, expr_reg); + program.emit_insn(Insn::AggStep { + acc_reg: target_register, + col: expr_reg, + func: crate::vdbe::AggFunc::Sum, + }); + target_register + } AggregationFunc::Total => todo!(), }; Ok(dest) diff --git a/core/types.rs b/core/types.rs index 86ef35beb..c2a1b5f04 100644 --- a/core/types.rs +++ b/core/types.rs @@ -11,11 +11,6 @@ pub enum Value<'a> { Blob(&'a Vec), } -#[derive(Debug, Clone, PartialEq)] -pub enum AggContext { - Avg(f64, usize), // acc and count -} - #[derive(Debug, Clone, PartialEq)] pub enum OwnedValue { Null, @@ -23,7 +18,105 @@ pub enum OwnedValue { Float(f64), Text(Rc), Blob(Rc>), - Agg(Box), + Agg(Box), // TODO(pere): make this without Box. Currently this might cause cache miss but let's leave it for future analysis +} + +#[derive(Debug, Clone, PartialEq)] +pub enum AggContext { + Avg(OwnedValue, OwnedValue), // acc and count + Sum(OwnedValue), +} + +impl std::ops::Add for OwnedValue { + type Output = OwnedValue; + + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (OwnedValue::Integer(int_left), OwnedValue::Integer(int_right)) => { + OwnedValue::Integer(int_left + int_right) + } + (OwnedValue::Integer(int_left), OwnedValue::Float(float_right)) => { + OwnedValue::Float(int_left as f64 + float_right) + } + (OwnedValue::Float(float_left), OwnedValue::Integer(int_right)) => { + OwnedValue::Float(float_left + int_right as f64) + } + (OwnedValue::Float(float_left), OwnedValue::Float(float_right)) => { + OwnedValue::Float(float_left + float_right) + } + _ => unreachable!(), + } + } +} + +impl std::ops::Add for OwnedValue { + type Output = OwnedValue; + + fn add(self, rhs: f64) -> Self::Output { + match self { + OwnedValue::Integer(int_left) => OwnedValue::Float(int_left as f64 + rhs), + OwnedValue::Float(float_left) => OwnedValue::Float(float_left + rhs), + _ => unreachable!(), + } + } +} + +impl std::ops::Add for OwnedValue { + type Output = OwnedValue; + + fn add(self, rhs: i64) -> Self::Output { + match self { + OwnedValue::Integer(int_left) => OwnedValue::Integer(int_left + rhs), + OwnedValue::Float(float_left) => OwnedValue::Float(float_left + rhs as f64), + _ => unreachable!(), + } + } +} + +impl std::ops::AddAssign for OwnedValue { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs; + } +} + +impl std::ops::AddAssign for OwnedValue { + fn add_assign(&mut self, rhs: i64) { + *self = self.clone() + rhs; + } +} + +impl std::ops::AddAssign for OwnedValue { + fn add_assign(&mut self, rhs: f64) { + *self = self.clone() + rhs; + } +} + +impl std::ops::Div for OwnedValue { + type Output = OwnedValue; + + fn div(self, rhs: OwnedValue) -> Self::Output { + match (self, rhs) { + (OwnedValue::Integer(int_left), OwnedValue::Integer(int_right)) => { + OwnedValue::Integer(int_left / int_right) + } + (OwnedValue::Integer(int_left), OwnedValue::Float(float_right)) => { + OwnedValue::Float(int_left as f64 / float_right) + } + (OwnedValue::Float(float_left), OwnedValue::Integer(int_right)) => { + OwnedValue::Float(float_left / int_right as f64) + } + (OwnedValue::Float(float_left), OwnedValue::Float(float_right)) => { + OwnedValue::Float(float_left / float_right) + } + _ => unreachable!(), + } + } +} + +impl std::ops::DivAssign for OwnedValue { + fn div_assign(&mut self, rhs: OwnedValue) { + *self = self.clone() / rhs; + } } pub fn to_value(value: &OwnedValue) -> Value<'_> { @@ -34,7 +127,8 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> { OwnedValue::Text(s) => Value::Text(s), OwnedValue::Blob(b) => Value::Blob(b), OwnedValue::Agg(a) => match a.as_ref() { - AggContext::Avg(acc, _count) => Value::Float(*acc), // we assume aggfinal was called + AggContext::Avg(acc, _count) => to_value(acc), // we assume aggfinal was called + AggContext::Sum(acc) => to_value(acc), _ => todo!(), }, } diff --git a/core/vdbe.rs b/core/vdbe.rs index 4c04c962f..5a4f68ab1 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -114,12 +114,14 @@ pub enum Insn { pub enum AggFunc { Avg, + Sum, } impl AggFunc { fn to_string(&self) -> &str { match self { AggFunc::Avg => "avg", + AggFunc::Sum => "sum", _ => "unknown", } } @@ -367,8 +369,10 @@ impl Program { }, Insn::AggStep { acc_reg, col, func } => { if let OwnedValue::Null = &state.registers[*acc_reg] { - state.registers[*acc_reg] = - OwnedValue::Agg(Box::new(AggContext::Avg(0.0, 0))); + state.registers[*acc_reg] = OwnedValue::Agg(Box::new(AggContext::Avg( + OwnedValue::Float(0.0), + OwnedValue::Integer(0), + ))); } match func { AggFunc::Avg => { @@ -377,12 +381,22 @@ impl Program { else { unreachable!(); }; - let AggContext::Avg(acc, count) = agg.borrow_mut(); - match col { - OwnedValue::Integer(i) => *acc += i as f64, - OwnedValue::Float(f) => *acc += f, - _ => unreachable!(), - } + let AggContext::Avg(acc, count) = agg.borrow_mut() else { + unreachable!(); + }; + *acc += col; + *count += 1; + } + AggFunc::Sum => { + let col = state.registers[*col].clone(); + let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() + else { + unreachable!(); + }; + let AggContext::Avg(acc, count) = agg.borrow_mut() else { + unreachable!(); + }; + *acc += col; *count += 1; } }; @@ -395,9 +409,12 @@ impl Program { else { unreachable!(); }; - let AggContext::Avg(acc, count) = agg.borrow_mut(); - *acc /= *count as f64 + let AggContext::Avg(acc, count) = agg.borrow_mut() else { + unreachable!(); + }; + *acc /= count.clone(); } + AggFunc::Sum => {} }; state.pc += 1; }