diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 3d4f08df5..a3a71e2c5 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -237,8 +237,10 @@ pub fn translate_condition_expr( )?; } ast::Expr::Binary(lhs, op, rhs) => { - let lhs_reg = translate_and_mark(program, Some(referenced_tables), lhs, resolver)?; - let rhs_reg = translate_and_mark(program, Some(referenced_tables), rhs, resolver)?; + let lhs_reg = program.alloc_register(); + let rhs_reg = program.alloc_register(); + translate_and_mark(program, Some(referenced_tables), lhs, lhs_reg, resolver)?; + translate_and_mark(program, Some(referenced_tables), rhs, rhs_reg, resolver)?; match op { ast::Operator::Greater => { emit_cmp_insn!(program, condition_metadata, Gt, Le, lhs_reg, rhs_reg) @@ -417,16 +419,17 @@ pub fn translate_condition_expr( let cur_reg = program.alloc_register(); match op { ast::LikeOperator::Like | ast::LikeOperator::Glob => { - let pattern_reg = program.alloc_register(); + let start_reg = program.alloc_registers(2); let mut constant_mask = 0; - let _ = translate_and_mark(program, Some(referenced_tables), lhs, resolver); - let _ = translate_expr( + translate_and_mark( program, Some(referenced_tables), - rhs, - pattern_reg, + lhs, + start_reg + 1, resolver, )?; + let _ = + translate_expr(program, Some(referenced_tables), rhs, start_reg, resolver)?; if matches!(rhs.as_ref(), ast::Expr::Literal(_)) { program.mark_last_insn_constant(); constant_mask = 1; @@ -438,7 +441,7 @@ pub fn translate_condition_expr( }; program.emit_insn(Insn::Function { constant_mask, - start_reg: pattern_reg, + start_reg, dest: cur_reg, func: FuncCtx { func: Func::Scalar(func), @@ -1003,16 +1006,23 @@ pub fn translate_expr( ) } JsonFunc::JsonRemove => { + let start_reg = + program.alloc_registers(args.as_ref().map(|x| x.len()).unwrap_or(1)); if let Some(args) = args { - for arg in args.iter() { + for (i, arg) in args.iter().enumerate() { // register containing result of each argument expression - let _ = - translate_and_mark(program, referenced_tables, arg, resolver)?; + translate_and_mark( + program, + referenced_tables, + arg, + start_reg + i, + resolver, + )?; } } program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: target_register + 1, + start_reg, dest: target_register, func: func_ctx, }); @@ -1337,11 +1347,17 @@ pub fn translate_expr( | ScalarFunc::Soundex | ScalarFunc::ZeroBlob => { let args = expect_arguments_exact!(args, 1, srf); - let reg = - translate_and_mark(program, referenced_tables, &args[0], resolver)?; + let start_reg = program.alloc_register(); + translate_and_mark( + program, + referenced_tables, + &args[0], + start_reg, + resolver, + )?; program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: reg, + start_reg, dest: target_register, func: func_ctx, }); @@ -1350,11 +1366,17 @@ pub fn translate_expr( #[cfg(not(target_family = "wasm"))] ScalarFunc::LoadExtension => { let args = expect_arguments_exact!(args, 1, srf); - let reg = - translate_and_mark(program, referenced_tables, &args[0], resolver)?; + let start_reg = program.alloc_register(); + translate_and_mark( + program, + referenced_tables, + &args[0], + start_reg, + resolver, + )?; program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: reg, + start_reg, dest: target_register, func: func_ctx, }); @@ -1377,20 +1399,23 @@ pub fn translate_expr( Ok(target_register) } ScalarFunc::Date | ScalarFunc::DateTime => { + let start_reg = program + .alloc_registers(args.as_ref().map(|x| x.len()).unwrap_or(1)); if let Some(args) = args { - for arg in args.iter() { + for (i, arg) in args.iter().enumerate() { // register containing result of each argument expression - let _ = translate_and_mark( + translate_and_mark( program, referenced_tables, arg, + start_reg + i, resolver, )?; } } program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: target_register + 1, + start_reg, dest: target_register, func: func_ctx, }); @@ -1457,11 +1482,17 @@ pub fn translate_expr( } else { crate::bail_parse_error!("hex function with no arguments",); }; - let regs = - translate_and_mark(program, referenced_tables, &args[0], resolver)?; + let start_reg = program.alloc_register(); + translate_and_mark( + program, + referenced_tables, + &args[0], + start_reg, + resolver, + )?; program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: regs, + start_reg, dest: target_register, func: func_ctx, }); @@ -1495,20 +1526,23 @@ pub fn translate_expr( Ok(target_register) } ScalarFunc::Time => { + let start_reg = program + .alloc_registers(args.as_ref().map(|x| x.len()).unwrap_or(1)); if let Some(args) = args { - for arg in args.iter() { + for (i, arg) in args.iter().enumerate() { // register containing result of each argument expression - let _ = translate_and_mark( + translate_and_mark( program, referenced_tables, arg, + start_reg + i, resolver, )?; } } program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: target_register + 1, + start_reg, dest: target_register, func: func_ctx, }); @@ -1537,12 +1571,19 @@ pub fn translate_expr( | ScalarFunc::Unhex => { let args = expect_arguments_max!(args, 2, srf); - for arg in args.iter() { - translate_and_mark(program, referenced_tables, arg, resolver)?; + let start_reg = program.alloc_registers(args.len()); + for (i, arg) in args.iter().enumerate() { + translate_and_mark( + program, + referenced_tables, + arg, + start_reg + i, + resolver, + )?; } program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: target_register + 1, + start_reg, dest: target_register, func: func_ctx, }); @@ -1559,13 +1600,20 @@ pub fn translate_expr( } else { crate::bail_parse_error!("min function with no arguments"); }; - for arg in args { - translate_and_mark(program, referenced_tables, arg, resolver)?; + let start_reg = program.alloc_registers(args.len()); + for (i, arg) in args.iter().enumerate() { + translate_and_mark( + program, + referenced_tables, + arg, + start_reg + i, + resolver, + )?; } program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: target_register + 1, + start_reg, dest: target_register, func: func_ctx, }); @@ -1582,13 +1630,20 @@ pub fn translate_expr( } else { crate::bail_parse_error!("max function with no arguments"); }; - for arg in args { - translate_and_mark(program, referenced_tables, arg, resolver)?; + let start_reg = program.alloc_registers(args.len()); + for (i, arg) in args.iter().enumerate() { + translate_and_mark( + program, + referenced_tables, + arg, + start_reg + i, + resolver, + )?; } program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: target_register + 1, + start_reg, dest: target_register, func: func_ctx, }); @@ -1725,20 +1780,23 @@ pub fn translate_expr( Ok(target_register) } ScalarFunc::StrfTime => { + let start_reg = program + .alloc_registers(args.as_ref().map(|x| x.len()).unwrap_or(1)); if let Some(args) = args { - for arg in args.iter() { + for (i, arg) in args.iter().enumerate() { // register containing result of each argument expression - let _ = translate_and_mark( + translate_and_mark( program, referenced_tables, arg, + start_reg + i, resolver, )?; } } program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: target_register + 1, + start_reg, dest: target_register, func: func_ctx, }); @@ -1771,11 +1829,17 @@ pub fn translate_expr( MathFuncArity::Unary => { let args = expect_arguments_exact!(args, 1, math_func); - let reg = - translate_and_mark(program, referenced_tables, &args[0], resolver)?; + let start_reg = program.alloc_register(); + translate_and_mark( + program, + referenced_tables, + &args[0], + start_reg, + resolver, + )?; program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: reg, + start_reg, dest: target_register, func: func_ctx, }); @@ -2165,14 +2229,14 @@ pub fn translate_and_mark( program: &mut ProgramBuilder, referenced_tables: Option<&[TableReference]>, expr: &ast::Expr, + target_register: usize, resolver: &Resolver, -) -> Result { - let target_register = program.alloc_register(); +) -> Result<()> { translate_expr(program, referenced_tables, expr, target_register, resolver)?; if matches!(expr, ast::Expr::Literal(_)) { program.mark_last_insn_constant(); } - Ok(target_register) + Ok(()) } /// Sanitaizes a string literal by removing single quote at front and back diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index e4f130c21..a49d11dc9 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -2146,19 +2146,31 @@ impl Program { } ScalarFunc::Trim => { let reg_value = state.registers[*start_reg].clone(); - let pattern_value = state.registers.get(*start_reg + 1).cloned(); + let pattern_value = if func.arg_count == 2 { + state.registers.get(*start_reg + 1).cloned() + } else { + None + }; let result = exec_trim(®_value, pattern_value); state.registers[*dest] = result; } ScalarFunc::LTrim => { let reg_value = state.registers[*start_reg].clone(); - let pattern_value = state.registers.get(*start_reg + 1).cloned(); + let pattern_value = if func.arg_count == 2 { + state.registers.get(*start_reg + 1).cloned() + } else { + None + }; let result = exec_ltrim(®_value, pattern_value); state.registers[*dest] = result; } ScalarFunc::RTrim => { let reg_value = state.registers[*start_reg].clone(); - let pattern_value = state.registers.get(*start_reg + 1).cloned(); + let pattern_value = if func.arg_count == 2 { + state.registers.get(*start_reg + 1).cloned() + } else { + None + }; let result = exec_rtrim(®_value, pattern_value); state.registers[*dest] = result; } @@ -3091,6 +3103,9 @@ fn exec_quote(value: &OwnedValue) -> OwnedValue { for c in s.as_str().chars() { if c == '\0' { break; + } else if c == '\'' { + quoted.push('\''); + quoted.push(c); } else { quoted.push(c); } @@ -3823,7 +3838,7 @@ mod tests { assert_eq!(exec_quote(&input), expected); let input = OwnedValue::build_text(Rc::new(String::from("hello''world"))); - let expected = OwnedValue::build_text(Rc::new(String::from("'hello''world'"))); + let expected = OwnedValue::build_text(Rc::new(String::from("'hello''''world'"))); assert_eq!(exec_quote(&input), expected); } diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index 488898b94..47cca09d4 100755 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -631,6 +631,10 @@ do_execsql_test quote-string { SELECT quote('limbo') } {'limbo'} +do_execsql_test quote-escape { + SELECT quote('''quote''') +} {'''quote'''} + do_execsql_test quote-null { SELECT quote(null) } {NULL} diff --git a/tests/integration/fuzz/mod.rs b/tests/integration/fuzz/mod.rs index a8dcd1d6f..9c5b00298 100644 --- a/tests/integration/fuzz/mod.rs +++ b/tests/integration/fuzz/mod.rs @@ -177,6 +177,113 @@ mod tests { } } + #[test] + pub fn string_expression_fuzz_run() { + let _ = env_logger::try_init(); + let g = GrammarGenerator::new(); + let (expr, expr_builder) = g.create_handle(); + let (bin_op, bin_op_builder) = g.create_handle(); + let (scalar, scalar_builder) = g.create_handle(); + let (paren, paren_builder) = g.create_handle(); + + paren_builder + .concat("") + .push_str("(") + .push(expr) + .push_str(")") + .build(); + + bin_op_builder + .concat(" ") + .push(expr) + .push(g.create().choice().options_str(["||"]).build()) + .push(expr) + .build(); + + scalar_builder + .choice() + .option( + g.create() + .concat("") + .push_str("char(") + .push( + g.create() + .concat("") + .push_symbol(rand_int(65..91)) + .repeat(1..8, ", ") + .build(), + ) + .push_str(")") + .build(), + ) + .option( + g.create() + .concat("") + .push( + g.create() + .choice() + .options_str(["ltrim", "rtrim", "trim"]) + .build(), + ) + .push_str("(") + .push(g.create().concat("").push(expr).repeat(2..3, ", ").build()) + .push_str(")") + .build(), + ) + .option( + g.create() + .concat("") + .push( + g.create() + .choice() + .options_str([ + "ltrim", "rtrim", "lower", "upper", "quote", "hex", "trim", + ]) + .build(), + ) + .push_str("(") + .push(expr) + .push_str(")") + .build(), + ) + .build(); + + expr_builder + .choice() + .option_w(bin_op, 1.0) + .option_w(paren, 1.0) + .option_w(scalar, 1.0) + .option( + g.create() + .concat("") + .push_str("'") + .push_symbol(rand_str("", 2)) + .push_str("'") + .build(), + ) + .build(); + + let sql = g.create().concat(" ").push_str("SELECT").push(expr).build(); + + let db = TempDatabase::new_empty(); + let limbo_conn = db.connect_limbo(); + let sqlite_conn = rusqlite::Connection::open_in_memory().unwrap(); + + let (mut rng, seed) = rng_from_time(); + log::info!("seed: {}", seed); + for _ in 0..1024 { + let query = g.generate(&mut rng, sql, 50); + log::info!("query: {}", query); + let limbo = limbo_exec_row(&limbo_conn, &query); + let sqlite = sqlite_exec_row(&sqlite_conn, &query); + assert_eq!( + limbo, sqlite, + "query: {}, limbo: {:?}, sqlite: {:?}", + query, limbo, sqlite + ); + } + } + #[test] pub fn logical_expression_fuzz_run() { let _ = env_logger::try_init();