diff --git a/core/incremental/expr_compiler.rs b/core/incremental/expr_compiler.rs index ae93a4b05..44b2cef49 100644 --- a/core/incremental/expr_compiler.rs +++ b/core/incremental/expr_compiler.rs @@ -92,14 +92,14 @@ pub enum ExpressionExecutor { } /// Trivial expression that can be evaluated inline without VDBE -/// Only supports operations where operands have the same type (no coercion) +/// Supports arithmetic operations with automatic type promotion (integer to float) #[derive(Clone, Debug)] pub enum TrivialExpression { /// Direct column reference Column(usize), /// Immediate value Immediate(Value), - /// Binary operation on trivial expressions (same-type operands only) + /// Binary operation on trivial expressions (supports type promotion) Binary { left: Box, op: Operator, @@ -109,7 +109,7 @@ pub enum TrivialExpression { impl TrivialExpression { /// Evaluate the trivial expression with the given input values - /// Panics if type mismatch occurs (this indicates a bug in validation) + /// Automatically promotes integers to floats when mixing types in arithmetic pub fn evaluate(&self, values: &[Value]) -> Value { match self { TrivialExpression::Column(idx) => values.get(*idx).cloned().unwrap_or(Value::Null), @@ -118,23 +118,32 @@ impl TrivialExpression { let left_val = left.evaluate(values); let right_val = right.evaluate(values); - // Only perform operations on same-type operands + // Perform operations with type promotion when needed match op { Operator::Add => match (&left_val, &right_val) { (Value::Integer(a), Value::Integer(b)) => Value::Integer(a + b), (Value::Float(a), Value::Float(b)) => Value::Float(a + b), + // Mixed integer/float - promote integer to float + (Value::Integer(a), Value::Float(b)) => Value::Float(*a as f64 + b), + (Value::Float(a), Value::Integer(b)) => Value::Float(a + *b as f64), (Value::Null, _) | (_, Value::Null) => Value::Null, _ => panic!("Type mismatch in trivial expression: {left_val:?} + {right_val:?}. This is a bug in trivial expression validation."), }, Operator::Subtract => match (&left_val, &right_val) { (Value::Integer(a), Value::Integer(b)) => Value::Integer(a - b), (Value::Float(a), Value::Float(b)) => Value::Float(a - b), + // Mixed integer/float - promote integer to float + (Value::Integer(a), Value::Float(b)) => Value::Float(*a as f64 - b), + (Value::Float(a), Value::Integer(b)) => Value::Float(a - *b as f64), (Value::Null, _) | (_, Value::Null) => Value::Null, _ => panic!("Type mismatch in trivial expression: {left_val:?} - {right_val:?}. This is a bug in trivial expression validation."), }, Operator::Multiply => match (&left_val, &right_val) { (Value::Integer(a), Value::Integer(b)) => Value::Integer(a * b), (Value::Float(a), Value::Float(b)) => Value::Float(a * b), + // Mixed integer/float - promote integer to float + (Value::Integer(a), Value::Float(b)) => Value::Float(*a as f64 * b), + (Value::Float(a), Value::Integer(b)) => Value::Float(a * *b as f64), (Value::Null, _) | (_, Value::Null) => Value::Null, _ => panic!("Type mismatch in trivial expression: {left_val:?} * {right_val:?}. This is a bug in trivial expression validation."), }, @@ -153,6 +162,21 @@ impl TrivialExpression { Value::Null } } + // Mixed integer/float - promote integer to float + (Value::Integer(a), Value::Float(b)) => { + if *b != 0.0 { + Value::Float(*a as f64 / b) + } else { + Value::Null + } + } + (Value::Float(a), Value::Integer(b)) => { + if *b != 0 { + Value::Float(a / *b as f64) + } else { + Value::Null + } + } (Value::Null, _) | (_, Value::Null) => Value::Null, _ => panic!("Type mismatch in trivial expression: {left_val:?} / {right_val:?}. This is a bug in trivial expression validation."), }, @@ -266,23 +290,27 @@ impl CompiledExpression { let right_trivial = Self::try_get_trivial_expr(right, input_column_names)?; // Check if we can determine types statically - // If both are immediates, they must have the same type - // If either is a column, we can't validate at compile time, - // but we'll assert at runtime if there's a mismatch + // For arithmetic operations, we allow mixing integers and floats + // since we promote integers to floats as needed if let (Some(left_type), Some(right_type)) = ( Self::get_trivial_type(&left_trivial), Self::get_trivial_type(&right_trivial), ) { - // Both types are known - they must match (or one is null) - if left_type != right_type - && left_type != TrivialType::Null - && right_type != TrivialType::Null - { - return None; // Type mismatch - not trivial + // Both types are known - check if they're numeric or null + let numeric_types = matches!( + left_type, + TrivialType::Integer | TrivialType::Float | TrivialType::Null + ) && matches!( + right_type, + TrivialType::Integer | TrivialType::Float | TrivialType::Null + ); + + if !numeric_types { + return None; // Non-numeric types - not trivial } } // If we can't determine types (columns involved), we optimistically - // assume they'll match at runtime (and assert if they don't) + // assume they'll be compatible at runtime Some(TrivialExpression::Binary { left: Box::new(left_trivial), @@ -450,3 +478,77 @@ impl CompiledExpression { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mixed_type_arithmetic() { + // Test integer - float + let expr = TrivialExpression::Binary { + left: Box::new(TrivialExpression::Immediate(Value::Integer(1))), + op: Operator::Subtract, + right: Box::new(TrivialExpression::Immediate(Value::Float(0.5))), + }; + let result = expr.evaluate(&[]); + assert_eq!(result, Value::Float(0.5)); + + // Test float - integer + let expr = TrivialExpression::Binary { + left: Box::new(TrivialExpression::Immediate(Value::Float(2.5))), + op: Operator::Subtract, + right: Box::new(TrivialExpression::Immediate(Value::Integer(1))), + }; + let result = expr.evaluate(&[]); + assert_eq!(result, Value::Float(1.5)); + + // Test integer * float + let expr = TrivialExpression::Binary { + left: Box::new(TrivialExpression::Immediate(Value::Integer(10))), + op: Operator::Multiply, + right: Box::new(TrivialExpression::Immediate(Value::Float(0.1))), + }; + let result = expr.evaluate(&[]); + assert_eq!(result, Value::Float(1.0)); + + // Test integer / float + let expr = TrivialExpression::Binary { + left: Box::new(TrivialExpression::Immediate(Value::Integer(1))), + op: Operator::Divide, + right: Box::new(TrivialExpression::Immediate(Value::Float(2.0))), + }; + let result = expr.evaluate(&[]); + assert_eq!(result, Value::Float(0.5)); + + // Test integer + float + let expr = TrivialExpression::Binary { + left: Box::new(TrivialExpression::Immediate(Value::Integer(1))), + op: Operator::Add, + right: Box::new(TrivialExpression::Immediate(Value::Float(0.5))), + }; + let result = expr.evaluate(&[]); + assert_eq!(result, Value::Float(1.5)); + } + + #[test] + fn test_nested_mixed_type_expressions() { + // Test nested expressions with mixed types: (1 - 0.04) + let one_minus_float = TrivialExpression::Binary { + left: Box::new(TrivialExpression::Immediate(Value::Integer(1))), + op: Operator::Subtract, + right: Box::new(TrivialExpression::Immediate(Value::Float(0.04))), + }; + let result = one_minus_float.evaluate(&[]); + assert_eq!(result, Value::Float(0.96)); + + // Test multiplication with nested mixed-type expression: 100.0 * (1 - 0.04) + let nested_expr = TrivialExpression::Binary { + left: Box::new(TrivialExpression::Immediate(Value::Float(100.0))), + op: Operator::Multiply, + right: Box::new(one_minus_float), + }; + let result = nested_expr.evaluate(&[]); + assert_eq!(result, Value::Float(96.0)); + } +}