diff --git a/core/types.rs b/core/types.rs index bccd647c0..a5ddaa906 100644 --- a/core/types.rs +++ b/core/types.rs @@ -401,14 +401,6 @@ pub struct ExternalAggState { pub argc: usize, pub step_fn: StepFunction, pub finalize_fn: FinalizeFunction, - pub finalized_value: Option, -} - -impl ExternalAggState { - pub fn cache_final_value(&mut self, value: Value) -> &Value { - self.finalized_value = Some(value); - self.finalized_value.as_ref().unwrap() - } } /// Please use Display trait for all limbo output so we have single origin of truth @@ -634,14 +626,13 @@ pub enum AggContext { } impl AggContext { - pub fn compute_external(&mut self) -> Result<()> { + pub fn compute_external(&self) -> Result { if let Self::External(ext_state) = self { - if ext_state.finalized_value.is_none() { - let final_value = unsafe { (ext_state.finalize_fn)(ext_state.state) }; - ext_state.cache_final_value(Value::from_ffi(final_value)?); - } + let final_value = unsafe { (ext_state.finalize_fn)(ext_state.state) }; + Value::from_ffi(final_value) + } else { + panic!("AggContext::compute_external() expected External, found {self:?}"); } - Ok(()) } } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 280caa520..41a3baed3 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -3468,7 +3468,6 @@ pub fn op_agg_step( argc: *argc, step_fn: *step, finalize_fn: *finalize, - finalized_value: None, })), _ => unreachable!("scalar function called in aggregate context"), }, @@ -3736,17 +3735,17 @@ pub fn op_agg_final( mv_store: Option<&Arc>, ) -> Result { load_insn!(AggFinal { register, func }, insn); - match state.registers[*register].borrow_mut() { + match &state.registers[*register] { Register::Aggregate(agg) => match func { AggFunc::Avg => { - let AggContext::Avg(acc, count) = agg.borrow_mut() else { + let AggContext::Avg(acc, count) = agg else { unreachable!(); }; - *acc /= count.clone(); - state.registers[*register] = Register::Value(acc.clone()); + let acc = acc.clone() / count.clone(); + state.registers[*register] = Register::Value(acc); } AggFunc::Sum => { - let AggContext::Sum(acc, sum_state) = agg.borrow_mut() else { + let AggContext::Sum(acc, sum_state) = agg else { unreachable!(); }; let value = match acc { @@ -3762,7 +3761,7 @@ pub fn op_agg_final( state.registers[*register] = Register::Value(value); } AggFunc::Total => { - let AggContext::Sum(acc, _) = agg.borrow_mut() else { + let AggContext::Sum(acc, _) = agg else { unreachable!(); }; let value = match acc { @@ -3774,13 +3773,13 @@ pub fn op_agg_final( state.registers[*register] = Register::Value(value); } AggFunc::Count | AggFunc::Count0 => { - let AggContext::Count(count) = agg.borrow_mut() else { + let AggContext::Count(count) = agg else { unreachable!(); }; state.registers[*register] = Register::Value(count.clone()); } AggFunc::Max => { - let AggContext::Max(acc) = agg.borrow_mut() else { + let AggContext::Max(acc) = agg else { unreachable!(); }; match acc { @@ -3789,7 +3788,7 @@ pub fn op_agg_final( } } AggFunc::Min => { - let AggContext::Min(acc) = agg.borrow_mut() else { + let AggContext::Min(acc) = agg else { unreachable!(); }; match acc { @@ -3798,14 +3797,14 @@ pub fn op_agg_final( } } AggFunc::GroupConcat | AggFunc::StringAgg => { - let AggContext::GroupConcat(acc) = agg.borrow_mut() else { + let AggContext::GroupConcat(acc) = agg else { unreachable!(); }; state.registers[*register] = Register::Value(acc.clone()); } #[cfg(feature = "json")] AggFunc::JsonGroupObject => { - let AggContext::GroupConcat(acc) = agg.borrow_mut() else { + let AggContext::GroupConcat(acc) = agg else { unreachable!(); }; let data = acc.to_blob().expect("Should be blob"); @@ -3813,7 +3812,7 @@ pub fn op_agg_final( } #[cfg(feature = "json")] AggFunc::JsonbGroupObject => { - let AggContext::GroupConcat(acc) = agg.borrow_mut() else { + let AggContext::GroupConcat(acc) = agg else { unreachable!(); }; let data = acc.to_blob().expect("Should be blob"); @@ -3821,7 +3820,7 @@ pub fn op_agg_final( } #[cfg(feature = "json")] AggFunc::JsonGroupArray => { - let AggContext::GroupConcat(acc) = agg.borrow_mut() else { + let AggContext::GroupConcat(acc) = agg else { unreachable!(); }; let data = acc.to_blob().expect("Should be blob"); @@ -3829,21 +3828,18 @@ pub fn op_agg_final( } #[cfg(feature = "json")] AggFunc::JsonbGroupArray => { - let AggContext::GroupConcat(acc) = agg.borrow_mut() else { + let AggContext::GroupConcat(acc) = agg else { unreachable!(); }; let data = acc.to_blob().expect("Should be blob"); state.registers[*register] = Register::Value(json_from_raw_bytes_agg(data, true)?); } AggFunc::External(_) => { - agg.compute_external()?; let AggContext::External(agg_state) = agg else { unreachable!(); }; - match &agg_state.finalized_value { - Some(value) => state.registers[*register] = Register::Value(value.clone()), - None => state.registers[*register] = Register::Value(Value::Null), - } + let value = agg.compute_external()?; + state.registers[*register] = Register::Value(value) } }, Register::Value(Value::Null) => {