From 22d096acc684030ad6bfcb52c822ab4400f5769f Mon Sep 17 00:00:00 2001 From: C4 Patino Date: Sun, 17 Aug 2025 21:42:49 -0500 Subject: [PATCH 01/73] fix(flake.nix): added uv dependency to flake.nix --- flake.nix | 1 + 1 file changed, 1 insertion(+) diff --git a/flake.nix b/flake.nix index e262c09c4..8f59224a9 100644 --- a/flake.nix +++ b/flake.nix @@ -71,6 +71,7 @@ python3 nodejs toolchain + uv ] ++ lib.optionals pkgs.stdenv.isDarwin [ apple-sdk ]; From 5ed2abf23f7793b50e2eb8c2ebc9ee9a66bb0635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mika=C3=ABl=20Francoeur?= Date: Thu, 21 Aug 2025 10:50:33 -0400 Subject: [PATCH 02/73] remove Result from signature --- core/json/jsonb.rs | 135 ++++++++++++++++++++++----------------------- core/json/mod.rs | 8 +-- 2 files changed, 70 insertions(+), 73 deletions(-) diff --git a/core/json/jsonb.rs b/core/json/jsonb.rs index cb4028abd..c4c95aac0 100644 --- a/core/json/jsonb.rs +++ b/core/json/jsonb.rs @@ -909,12 +909,13 @@ impl Jsonb { } } - pub fn to_string(&self) -> Result { + #[expect(clippy::inherent_to_string)] + pub fn to_string(&self) -> String { let mut result = String::with_capacity(self.data.len() * 2); - self.write_to_string(&mut result, JsonIndentation::None)?; + self.write_to_string(&mut result, JsonIndentation::None); - Ok(result) + result } pub fn to_string_pretty(&self, indentation: Option<&str>) -> Result { @@ -924,16 +925,15 @@ impl Jsonb { } else { JsonIndentation::Indentation(Cow::Borrowed(" ")) }; - self.write_to_string(&mut result, ind)?; + self.write_to_string(&mut result, ind); Ok(result) } - fn write_to_string(&self, string: &mut String, indentation: JsonIndentation) -> Result<()> { + fn write_to_string(&self, string: &mut String, indentation: JsonIndentation) { let cursor = 0; let ind = indentation; let _ = self.serialize_value(string, cursor, 0, &ind); - Ok(()) } fn serialize_value( @@ -3093,7 +3093,7 @@ mod tests { jsonb.data.push(ElementType::NULL as u8); // Test serialization - let json_str = 
jsonb.to_string().unwrap(); + let json_str = jsonb.to_string(); assert_eq!(json_str, "null"); // Test round-trip @@ -3106,12 +3106,12 @@ mod tests { // True let mut jsonb_true = Jsonb::new(10, None); jsonb_true.data.push(ElementType::TRUE as u8); - assert_eq!(jsonb_true.to_string().unwrap(), "true"); + assert_eq!(jsonb_true.to_string(), "true"); // False let mut jsonb_false = Jsonb::new(10, None); jsonb_false.data.push(ElementType::FALSE as u8); - assert_eq!(jsonb_false.to_string().unwrap(), "false"); + assert_eq!(jsonb_false.to_string(), "false"); // Round-trip let true_parsed = Jsonb::from_str("true").unwrap(); @@ -3125,15 +3125,15 @@ mod tests { fn test_integer_serialization() { // Standard integer let parsed = Jsonb::from_str("42").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "42"); + assert_eq!(parsed.to_string(), "42"); // Negative integer let parsed = Jsonb::from_str("-123").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "-123"); + assert_eq!(parsed.to_string(), "-123"); // Zero let parsed = Jsonb::from_str("0").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "0"); + assert_eq!(parsed.to_string(), "0"); // Verify correct type let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; @@ -3144,15 +3144,15 @@ mod tests { fn test_json5_integer_serialization() { // Hexadecimal notation let parsed = Jsonb::from_str("0x1A").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "26"); // Should convert to decimal + assert_eq!(parsed.to_string(), "26"); // Should convert to decimal // Positive sign (JSON5) let parsed = Jsonb::from_str("+42").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "42"); + assert_eq!(parsed.to_string(), "42"); // Negative hexadecimal let parsed = Jsonb::from_str("-0xFF").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "-255"); + assert_eq!(parsed.to_string(), "-255"); // Verify correct type let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; @@ -3163,15 +3163,15 @@ mod tests { fn 
test_float_serialization() { // Standard float let parsed = Jsonb::from_str("3.14159").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "3.14159"); + assert_eq!(parsed.to_string(), "3.14159"); // Negative float let parsed = Jsonb::from_str("-2.718").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "-2.718"); + assert_eq!(parsed.to_string(), "-2.718"); // Scientific notation let parsed = Jsonb::from_str("6.022e23").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "6.022e23"); + assert_eq!(parsed.to_string(), "6.022e23"); // Verify correct type let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; @@ -3182,23 +3182,23 @@ mod tests { fn test_json5_float_serialization() { // Leading decimal point let parsed = Jsonb::from_str(".123").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "0.123"); + assert_eq!(parsed.to_string(), "0.123"); // Trailing decimal point let parsed = Jsonb::from_str("42.").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "42.0"); + assert_eq!(parsed.to_string(), "42.0"); // Plus sign in exponent let parsed = Jsonb::from_str("1.5e+10").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "1.5e+10"); + assert_eq!(parsed.to_string(), "1.5e+10"); // Infinity let parsed = Jsonb::from_str("Infinity").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "9e999"); + assert_eq!(parsed.to_string(), "9e999"); // Negative Infinity let parsed = Jsonb::from_str("-Infinity").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "-9e999"); + assert_eq!(parsed.to_string(), "-9e999"); // Verify correct type let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; @@ -3209,15 +3209,15 @@ mod tests { fn test_string_serialization() { // Simple string let parsed = Jsonb::from_str(r#""hello world""#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""hello world""#); + assert_eq!(parsed.to_string(), r#""hello world""#); // String with escaped characters let parsed = Jsonb::from_str(r#""hello\nworld""#).unwrap(); - 
assert_eq!(parsed.to_string().unwrap(), r#""hello\nworld""#); + assert_eq!(parsed.to_string(), r#""hello\nworld""#); // Unicode escape let parsed = Jsonb::from_str(r#""hello\u0020world""#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""hello\u0020world""#); + assert_eq!(parsed.to_string(), r#""hello\u0020world""#); // Verify correct type let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; @@ -3228,11 +3228,11 @@ mod tests { fn test_json5_string_serialization() { // Single quotes let parsed = Jsonb::from_str("'hello world'").unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""hello world""#); + assert_eq!(parsed.to_string(), r#""hello world""#); // Hex escape let parsed = Jsonb::from_str(r#"'\x41\x42\x43'"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""\u0041\u0042\u0043""#); + assert_eq!(parsed.to_string(), r#""\u0041\u0042\u0043""#); // Multiline string with line continuation let parsed = Jsonb::from_str( @@ -3240,11 +3240,11 @@ mod tests { world""#, ) .unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""hello world""#); + assert_eq!(parsed.to_string(), r#""hello world""#); // Escaped single quote let parsed = Jsonb::from_str(r#"'Don\'t worry'"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#""Don't worry""#); + assert_eq!(parsed.to_string(), r#""Don't worry""#); // Verify correct type let header = JsonbHeader::from_slice(0, &parsed.data).unwrap().0; @@ -3255,20 +3255,20 @@ world""#, fn test_array_serialization() { // Empty array let parsed = Jsonb::from_str("[]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[]"); + assert_eq!(parsed.to_string(), "[]"); // Simple array let parsed = Jsonb::from_str("[1,2,3]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); // Nested array let parsed = Jsonb::from_str("[[1,2],[3,4]]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[[1,2],[3,4]]"); + assert_eq!(parsed.to_string(), "[[1,2],[3,4]]"); // Mixed 
types array let parsed = Jsonb::from_str(r#"[1,"text",true,null,{"key":"value"}]"#).unwrap(); assert_eq!( - parsed.to_string().unwrap(), + parsed.to_string(), r#"[1,"text",true,null,{"key":"value"}]"# ); @@ -3281,44 +3281,41 @@ world""#, fn test_json5_array_serialization() { // Trailing comma let parsed = Jsonb::from_str("[1,2,3,]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); // Comments in array let parsed = Jsonb::from_str("[1,/* comment */2,3]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); // Line comment in array let parsed = Jsonb::from_str("[1,// line comment\n2,3]").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); } #[test] fn test_object_serialization() { // Empty object let parsed = Jsonb::from_str("{}").unwrap(); - assert_eq!(parsed.to_string().unwrap(), "{}"); + assert_eq!(parsed.to_string(), "{}"); // Simple object let parsed = Jsonb::from_str(r#"{"key":"value"}"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"key":"value"}"#); // Multiple properties let parsed = Jsonb::from_str(r#"{"a":1,"b":2,"c":3}"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"a":1,"b":2,"c":3}"#); + assert_eq!(parsed.to_string(), r#"{"a":1,"b":2,"c":3}"#); // Nested object let parsed = Jsonb::from_str(r#"{"outer":{"inner":"value"}}"#).unwrap(); - assert_eq!( - parsed.to_string().unwrap(), - r#"{"outer":{"inner":"value"}}"# - ); + assert_eq!(parsed.to_string(), r#"{"outer":{"inner":"value"}}"#); // Mixed values let parsed = Jsonb::from_str(r#"{"str":"text","num":42,"bool":true,"null":null,"arr":[1,2]}"#) .unwrap(); assert_eq!( - parsed.to_string().unwrap(), + parsed.to_string(), r#"{"str":"text","num":42,"bool":true,"null":null,"arr":[1,2]}"# ); @@ -3331,19 +3328,19 @@ world""#, fn test_json5_object_serialization() { 
// Unquoted keys let parsed = Jsonb::from_str("{key:\"value\"}").unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"key":"value"}"#); // Trailing comma let parsed = Jsonb::from_str(r#"{"a":1,"b":2,}"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"a":1,"b":2}"#); + assert_eq!(parsed.to_string(), r#"{"a":1,"b":2}"#); // Comments in object let parsed = Jsonb::from_str(r#"{"a":1,/*comment*/"b":2}"#).unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"a":1,"b":2}"#); + assert_eq!(parsed.to_string(), r#"{"a":1,"b":2}"#); // Single quotes for keys and values let parsed = Jsonb::from_str("{'a':'value'}").unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"a":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"a":"value"}"#); } #[test] @@ -3366,8 +3363,8 @@ world""#, let parsed = Jsonb::from_str(complex_json).unwrap(); // Round-trip test - let reparsed = Jsonb::from_str(&parsed.to_string().unwrap()).unwrap(); - assert_eq!(parsed.to_string().unwrap(), reparsed.to_string().unwrap()); + let reparsed = Jsonb::from_str(&parsed.to_string()).unwrap(); + assert_eq!(parsed.to_string(), reparsed.to_string()); } #[test] @@ -3486,11 +3483,11 @@ world""#, fn test_unicode_escapes() { // Basic unicode escape let parsed = Jsonb::from_str(r#""\u00A9""#).unwrap(); // Copyright symbol - assert_eq!(parsed.to_string().unwrap(), r#""\u00A9""#); + assert_eq!(parsed.to_string(), r#""\u00A9""#); // Non-BMP character (surrogate pair) let parsed = Jsonb::from_str(r#""\uD83D\uDE00""#).unwrap(); // Smiley emoji - assert_eq!(parsed.to_string().unwrap(), r#""\uD83D\uDE00""#); + assert_eq!(parsed.to_string(), r#""\uD83D\uDE00""#); } #[test] @@ -3503,7 +3500,7 @@ world""#, }"#, ) .unwrap(); - assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"key":"value"}"#); // Block comments let parsed = Jsonb::from_str( @@ -3514,7 +3511,7 @@ world""#, }"#, ) .unwrap(); - 
assert_eq!(parsed.to_string().unwrap(), r#"{"key":"value"}"#); + assert_eq!(parsed.to_string(), r#"{"key":"value"}"#); // Comments inside array let parsed = Jsonb::from_str( @@ -3522,7 +3519,7 @@ world""#, 2, /* Another comment */ 3]"#, ) .unwrap(); - assert_eq!(parsed.to_string().unwrap(), "[1,2,3]"); + assert_eq!(parsed.to_string(), "[1,2,3]"); } #[test] @@ -3540,7 +3537,7 @@ world""#, let parsed = Jsonb::from_str(json_with_whitespace).unwrap(); assert_eq!( - parsed.to_string().unwrap(), + parsed.to_string(), r#"{"key1":"value1","key2":[1,2,3],"key3":{"nested":true}}"# ); } @@ -3554,7 +3551,7 @@ world""#, // Create a new Jsonb from the binary data let from_binary = Jsonb::new(0, Some(&binary_data)); - assert_eq!(from_binary.to_string().unwrap(), original); + assert_eq!(from_binary.to_string(), original); } #[test] @@ -3570,8 +3567,8 @@ world""#, large_array.push(']'); let parsed = Jsonb::from_str(&large_array).unwrap(); - assert!(parsed.to_string().unwrap().starts_with("[0,1,2,")); - assert!(parsed.to_string().unwrap().ends_with("998,999]")); + assert!(parsed.to_string().starts_with("[0,1,2,")); + assert!(parsed.to_string().ends_with("998,999]")); } #[test] @@ -3600,7 +3597,7 @@ world""#, }"#; let parsed = Jsonb::from_str(json).unwrap(); - let result = parsed.to_string().unwrap(); + let result = parsed.to_string(); assert!(result.contains(r#""escaped_quotes":"He said \"Hello\"""#)); assert!(result.contains(r#""backslashes":"C:\\Windows\\System32""#)); @@ -3767,7 +3764,7 @@ mod path_operations_tests { assert!(result.is_ok()); // Verify the value was updated - let updated_json = jsonb.to_string().unwrap(); + let updated_json = jsonb.to_string(); assert_eq!(updated_json, r#"{"name":"Jane","age":30}"#); } @@ -3791,7 +3788,7 @@ mod path_operations_tests { assert!(result.is_ok()); // Verify the value was inserted - let updated_json = jsonb.to_string().unwrap(); + let updated_json = jsonb.to_string(); assert_eq!(updated_json, r#"{"name":"John","age":30}"#); } @@ -3814,7 
+3811,7 @@ mod path_operations_tests { assert!(result.is_ok()); // Verify the property was deleted - let updated_json = jsonb.to_string().unwrap(); + let updated_json = jsonb.to_string(); assert_eq!(updated_json, r#"{"name":"John"}"#); } @@ -3839,7 +3836,7 @@ mod path_operations_tests { assert!(result.is_ok()); // Verify the value was replaced - let updated_json = jsonb.to_string().unwrap(); + let updated_json = jsonb.to_string(); assert_eq!(updated_json, r#"{"items":[10,50,30]}"#); } @@ -3863,7 +3860,7 @@ mod path_operations_tests { // Get the search result let search_result = operation.result(); - let result_str = search_result.to_string().unwrap(); + let result_str = search_result.to_string(); // Verify the search found the correct value assert_eq!(result_str, r#"{"name":"John","age":30}"#); @@ -3912,7 +3909,7 @@ mod path_operations_tests { assert!(result.is_ok()); // Verify the deep value was updated - let updated_json = jsonb.to_string().unwrap(); + let updated_json = jsonb.to_string(); assert_eq!( updated_json, r#"{"level1":{"level2":{"level3":{"value":100}}}}"# @@ -3953,7 +3950,7 @@ mod path_operations_tests { let result = jsonb.operate_on_path(&path, &mut operation); assert!(result.is_ok()); - let updated_json = jsonb.to_string().unwrap(); + let updated_json = jsonb.to_string(); assert_eq!(updated_json, r#"{"name":"John","age":30}"#); // 3. 
InsertNew mode - should fail when path already exists @@ -3991,7 +3988,7 @@ mod path_operations_tests { let result = jsonb.operate_on_path(&path, &mut operation); assert!(result.is_ok()); - let updated_json = jsonb.to_string().unwrap(); + let updated_json = jsonb.to_string(); assert_eq!(updated_json, r#"{"name":"John","age":31,"surname":"Doe"}"#); } } diff --git a/core/json/mod.rs b/core/json/mod.rs index 5e8aeadef..caa1b28a0 100644 --- a/core/json/mod.rs +++ b/core/json/mod.rs @@ -44,7 +44,7 @@ pub fn get_json(json_value: &Value, indent: Option<&str>) -> crate::Result json_val.to_string_pretty(Some(indent))?, - None => json_val.to_string()?, + None => json_val.to_string(), }; Ok(Value::Text(Text::json(json))) @@ -53,7 +53,7 @@ pub fn get_json(json_value: &Value, indent: Option<&str>) -> crate::Result crate::Result { - let mut json_string = json.to_string()?; + let mut json_string = json.to_string(); if matches!(flag, OutputVariant::Binary) { return Ok(Value::Blob(json.data())); } From 1d24925e215d4e57541f7966d48a0e05b6493e0e Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Sat, 23 Aug 2025 00:21:50 +0300 Subject: [PATCH 03/73] Make fill_cell_payload() safe for async IO and cache spilling Problems: 1. fill_cell_payload() is not re-entrant because it can yield IO on allocating a new overflow page, resulting in losing some of the input data. 2. fill_cell_payload() in its current form is not safe for cache spilling because the previous overflow page in the chain of allocated overflow pages can be evicted by a spill caused by the next overflow page allocation, invalidating the page pointer and causing corruption. 3. fill_cell_payload() uses raw pointers and `unsafe` as a workaround from a previous time when we used to clone `WriteState`, resulting in hard-to-read code. Solutions: 1. Introduce a new substate to the fill_cell_payload state machine to handle re-entrancy wrt. allocating overflow pages. 2. 
Always pin the current overflow page so that it cannot be evicted during the overflow chain construction. Also pin the regular page the overflow chain is attached to, because it is immediately accessed after fill_cell_payload is done. 3. Remove all explicit usages of `unsafe` from `fill_cell_payload` (although our pager is ofc still extremely unsafe under the hood :] ) Note that solution 2 addresses a problem that arose in the development of page cache spilling, which is not yet implemented, but will be soon. Miscellania: 1. Renamed a bunch of variables to be clearer 2. Added more comments about what is happening in fill_cell_payload --- core/storage/btree.rs | 219 +++++++++++++++++++++++++----------------- 1 file changed, 129 insertions(+), 90 deletions(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index ca5624220..ab263e406 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -2274,7 +2274,7 @@ impl BTreeCursor { ref mut fill_cell_payload_state, } => { return_if_io!(fill_cell_payload( - page.get().get().contents.as_ref().unwrap(), + page.get(), bkey.maybe_rowid(), new_payload, *cell_idx, @@ -5176,10 +5176,9 @@ impl BTreeCursor { fill_cell_payload_state, } => { let page = page_ref.get(); - let page_contents = page.get().contents.as_ref().unwrap(); { return_if_io!(fill_cell_payload( - page_contents, + page, *rowid, new_payload, cell_idx, @@ -6963,39 +6962,66 @@ fn allocate_cell_space( #[derive(Debug, Clone)] pub enum FillCellPayloadState { + /// Determine whether we can fit the record on the current page. + /// If yes, return immediately after copying the data. + /// Otherwise move to [CopyData] state. Start, - AllocateOverflowPages { - /// Arc because we clone [WriteState] for some reason and we use unsafe pointer dereferences in [FillCellPayloadState::AllocateOverflowPages] - /// so the underlying bytes must not be cloned in upper layers. 
- record_buf: Arc<[u8]>, - space_left: usize, - to_copy_buffer_ptr: *const u8, - to_copy_buffer_len: usize, - pointer: *mut u8, - pointer_to_next: *mut u8, + /// Copy the next chunk of data from the record buffer to the cell payload. + /// If we can't fit all of the remaining data on the current page, + /// move the internal state to [CopyDataState::AllocateOverflowPage] + CopyData { + /// Internal state of the copy data operation. + /// We can either be copying data or allocating an overflow page. + state: CopyDataState, + /// Track how much space we have left on the current page we are copying data into. + /// This is reset whenever a new overflow page is allocated. + space_left_on_cur_page: usize, + /// Offset into the record buffer to copy from. + src_data_offset: usize, + /// Offset into the destination buffer we are copying data into. + /// This is either: + /// - an offset in the btree page where the cell is, or + /// - an offset in an overflow page + dst_data_offset: usize, + /// If this is Some, we will copy data into this overflow page. + /// If this is None, we will copy data into the cell payload on the btree page. + /// Also: to safely form a chain of overflow pages, the current page must be pinned to the page cache + /// so that e.g. a spilling operation does not evict it to disk. + current_overflow_page: Option, }, } +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum CopyDataState { + /// Copy the next chunk of data from the record buffer to the cell payload. + Copy, + /// Allocate a new overflow page if we couldn't fit all data to the current page. + AllocateOverflowPage, +} + /// Fill in the cell payload with the record. /// If the record is too large to fit in the cell, it will spill onto overflow pages. /// This function needs a separate [FillCellPayloadState] because allocating overflow pages /// may require I/O. 
#[allow(clippy::too_many_arguments)] fn fill_cell_payload( - page_contents: &PageContent, + page: PageRef, int_key: Option, cell_payload: &mut Vec, cell_idx: usize, record: &ImmutableRecord, usable_space: usize, pager: Rc, - state: &mut FillCellPayloadState, + fill_cell_payload_state: &mut FillCellPayloadState, ) -> Result> { + let overflow_page_pointer_size = 4; + let overflow_page_data_size = usable_space - overflow_page_pointer_size; loop { - match state { + let record_buf = record.get_payload(); + match fill_cell_payload_state { FillCellPayloadState::Start => { - // TODO: make record raw from start, having to serialize is not good - let record_buf: Arc<[u8]> = Arc::from(record.get_payload()); + page.pin(); // We need to pin this page because we will be accessing its contents after fill_cell_payload is done. + let page_contents = page.get().contents.as_ref().unwrap(); let page_type = page_contents.page_type(); // fill in header @@ -7024,25 +7050,25 @@ fn fill_cell_payload( if record_buf.len() <= payload_overflow_threshold_max { // enough allowed space to fit inside a btree page cell_payload.extend_from_slice(record_buf.as_ref()); - return Ok(IOResult::Done(())); + break; } let payload_overflow_threshold_min = payload_overflow_threshold_min(page_type, usable_space); // see e.g. https://github.com/sqlite/sqlite/blob/9591d3fe93936533c8c3b0dc4d025ac999539e11/src/dbstat.c#L371 let mut space_left = payload_overflow_threshold_min - + (record_buf.len() - payload_overflow_threshold_min) % (usable_space - 4); + + (record_buf.len() - payload_overflow_threshold_min) % overflow_page_data_size; if space_left > payload_overflow_threshold_max { space_left = payload_overflow_threshold_min; } // cell_size must be equal to first value of space_left as this will be the bytes copied to non-overflow page. 
- let cell_size = space_left + cell_payload.len() + 4; // 4 is the number of bytes of pointer to first overflow page - let to_copy_buffer = record_buf.as_ref(); + let cell_size = space_left + cell_payload.len() + overflow_page_pointer_size; let prev_size = cell_payload.len(); - cell_payload.resize(prev_size + space_left + 4, 0); + let new_data_size = prev_size + space_left; + cell_payload.resize(new_data_size + overflow_page_pointer_size, 0); assert_eq!( cell_size, cell_payload.len(), @@ -7051,87 +7077,100 @@ fn fill_cell_payload( cell_payload.len() ); - // SAFETY: this pointer is valid because it points to a buffer in an Arc>> that lives at least as long as this function, - // and the Vec will not be mutated in FillCellPayloadState::AllocateOverflowPages, which we will move to next. - let pointer = unsafe { cell_payload.as_mut_ptr().add(prev_size) }; - let pointer_to_next = - unsafe { cell_payload.as_mut_ptr().add(prev_size + space_left) }; - - let to_copy_buffer_ptr = to_copy_buffer.as_ptr(); - let to_copy_buffer_len = to_copy_buffer.len(); - - *state = FillCellPayloadState::AllocateOverflowPages { - record_buf, - space_left, - to_copy_buffer_ptr, - to_copy_buffer_len, - pointer, - pointer_to_next, + *fill_cell_payload_state = FillCellPayloadState::CopyData { + state: CopyDataState::Copy, + space_left_on_cur_page: space_left, + src_data_offset: 0, + dst_data_offset: prev_size, + current_overflow_page: None, }; continue; } - FillCellPayloadState::AllocateOverflowPages { - record_buf: _record_buf, - space_left, - to_copy_buffer_ptr, - to_copy_buffer_len, - pointer, - pointer_to_next, + FillCellPayloadState::CopyData { + state, + src_data_offset, + space_left_on_cur_page, + dst_data_offset, + current_overflow_page, } => { - let to_copy; - { - let to_copy_buffer_ptr = *to_copy_buffer_ptr; - let to_copy_buffer_len = *to_copy_buffer_len; - let pointer = *pointer; - let space_left = *space_left; + match state { + CopyDataState::Copy => { + turso_assert!(*src_data_offset 
< record_buf.len(), "trying to read past end of record buffer: record_offset={} < record_buf.len()={}", src_data_offset, record_buf.len()); + let record_offset_slice = &record_buf[*src_data_offset..]; + let amount_to_copy = + (*space_left_on_cur_page).min(record_offset_slice.len()); + let record_offset_slice_to_copy = &record_offset_slice[..amount_to_copy]; + if let Some(cur_page) = current_overflow_page { + // Copy data into the current overflow page. + turso_assert!( + cur_page.is_loaded(), + "current overflow page is not loaded" + ); + turso_assert!(*dst_data_offset == overflow_page_pointer_size, "data must be copied to offset {overflow_page_pointer_size} on overflow pages, instead tried to copy to offset {dst_data_offset}"); + let contents = cur_page.get_contents(); + let buf = &mut contents.as_ptr() + [*dst_data_offset..*dst_data_offset + amount_to_copy]; + buf.copy_from_slice(record_offset_slice_to_copy); + } else { + // Copy data into the cell payload on the btree page. + let buf = &mut cell_payload + [*dst_data_offset..*dst_data_offset + amount_to_copy]; + buf.copy_from_slice(record_offset_slice_to_copy); + } - // SAFETY: we know to_copy_buffer_ptr is valid because it refers to record_buf which lives at least as long as this function, - // and the underlying bytes are not mutated in FillCellPayloadState::AllocateOverflowPages. - let to_copy_buffer = unsafe { - std::slice::from_raw_parts(to_copy_buffer_ptr, to_copy_buffer_len) - }; - to_copy = space_left.min(to_copy_buffer_len); - // SAFETY: we know 'pointer' is valid because it refers to cell_payload which lives at least as long as this function, - // and the underlying bytes are not mutated in FillCellPayloadState::AllocateOverflowPages. - unsafe { std::ptr::copy(to_copy_buffer_ptr, pointer, to_copy) }; + if record_offset_slice.len() - amount_to_copy == 0 { + if let Some(cur_page) = current_overflow_page { + cur_page.unpin(); // We can safely unpin the current overflow page now. + } + // Everything copied. 
+ break; + } + *state = CopyDataState::AllocateOverflowPage; + *src_data_offset += amount_to_copy; + } + CopyDataState::AllocateOverflowPage => { + let new_overflow_page = return_if_io!(pager.allocate_overflow_page()); + new_overflow_page.pin(); // Pin the current overflow page so the cache won't evict it because we need this page to be in memory for the next iteration of FillCellPayloadState::CopyData. + if let Some(prev_page) = current_overflow_page { + prev_page.unpin(); // We can safely unpin the previous overflow page now. + } - let left = to_copy_buffer.len() - to_copy; - if left == 0 { - break; + turso_assert!( + new_overflow_page.is_loaded(), + "new overflow page is not loaded" + ); + let new_overflow_page_id = new_overflow_page.get().id as u32; + + if let Some(prev_page) = current_overflow_page { + // Update the previous overflow page's "next overflow page" pointer to point to the new overflow page. + turso_assert!( + prev_page.is_loaded(), + "previous overflow page is not loaded" + ); + let contents = prev_page.get_contents(); + let buf = &mut contents.as_ptr()[..overflow_page_pointer_size]; + buf.copy_from_slice(&new_overflow_page_id.to_be_bytes()); + } else { + // Update the cell payload's "next overflow page" pointer to point to the new overflow page. 
+ let first_overflow_page_ptr_offset = + cell_payload.len() - overflow_page_pointer_size; + let buf = &mut cell_payload[first_overflow_page_ptr_offset + ..first_overflow_page_ptr_offset + overflow_page_pointer_size]; + buf.copy_from_slice(&new_overflow_page_id.to_be_bytes()); + } + + *dst_data_offset = overflow_page_pointer_size; + *space_left_on_cur_page = overflow_page_data_size; + *current_overflow_page = Some(new_overflow_page.clone()); + *state = CopyDataState::Copy; } } - - // we still have bytes to add, we will need to allocate new overflow page - // FIXME: handle page cache is full - let overflow_page = return_if_io!(pager.allocate_overflow_page()); - turso_assert!(overflow_page.is_loaded(), "overflow page is not loaded"); - { - let id = overflow_page.get().id as u32; - let contents = overflow_page.get_contents(); - - // TODO: take into account offset here? - let buf = contents.as_ptr(); - let as_bytes = id.to_be_bytes(); - // update pointer to new overflow page - // SAFETY: we know 'pointer_to_next' is valid because it refers to an offset in cell_payload which is less than space_left + 4, - // and the underlying bytes are not mutated in FillCellPayloadState::AllocateOverflowPages. - unsafe { std::ptr::copy(as_bytes.as_ptr(), *pointer_to_next, 4) }; - - *pointer = unsafe { buf.as_mut_ptr().add(4) }; - *pointer_to_next = buf.as_mut_ptr(); - *space_left = usable_space - 4; - } - - *to_copy_buffer_len -= to_copy; - // SAFETY: we know 'to_copy_buffer_ptr' is valid because it refers to record_buf which lives at least as long as this function, - // and that the offset is less than its length, and the underlying bytes are not mutated in FillCellPayloadState::AllocateOverflowPages. - *to_copy_buffer_ptr = unsafe { to_copy_buffer_ptr.add(to_copy) }; } } } + page.unpin(); Ok(IOResult::Done(())) } - /// Returns the maximum payload size (X) that can be stored directly on a b-tree page without spilling to overflow pages. 
/// /// For table leaf pages: X = usable_size - 35 From b4ee40dd3d048612aa823e738ba2b19e2964b61d Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Sat, 23 Aug 2025 00:25:50 +0300 Subject: [PATCH 04/73] fix tests --- core/storage/btree.rs | 249 +++++++++++++++++++++--------------------- 1 file changed, 125 insertions(+), 124 deletions(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index ab263e406..962e5d51a 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -7324,7 +7324,7 @@ mod tests { fn add_record( id: usize, pos: usize, - page: &mut PageContent, + page: PageRef, record: ImmutableRecord, conn: &Arc, ) -> Vec { @@ -7333,7 +7333,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page, + page.clone(), Some(id as i64), &mut payload, pos, @@ -7346,7 +7346,7 @@ mod tests { &conn.pager.borrow().clone(), ) .unwrap(); - insert_into_cell(page, &payload, pos, 4096).unwrap(); + insert_into_cell(page.get_contents(), &payload, pos, 4096).unwrap(); payload } @@ -7356,17 +7356,17 @@ mod tests { let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get(); - let page = page.get_contents(); let header_size = 8; let regs = &[Register::Value(Value::Integer(1))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(1, 0, page, record, &conn); - assert_eq!(page.cell_count(), 1); - let free = compute_free_space(page, 4096); + let payload = add_record(1, 0, page.clone(), record, &conn); + let page_contents = page.get_contents(); + assert_eq!(page_contents.cell_count(), 1); + let free = compute_free_space(page_contents, 4096); assert_eq!(free, 4096 - payload.len() - 2 - header_size); let cell_idx = 0; - ensure_cell(page, cell_idx, &payload); + ensure_cell(page_contents, cell_idx, &payload); } struct Cell { @@ -7381,7 +7381,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut 
total_size = 0; @@ -7390,22 +7390,22 @@ mod tests { for i in 0..3 { let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(i, i, page, record, &conn); - assert_eq!(page.cell_count(), i + 1); - let free = compute_free_space(page, usable_space); + let payload = add_record(i, i, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), i + 1); + let free = compute_free_space(page_contents, usable_space); total_size += payload.len() + 2; assert_eq!(free, 4096 - total_size - header_size); cells.push(Cell { pos: i, payload }); } for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } cells.remove(1); - drop_cell(page, 1, usable_space).unwrap(); + drop_cell(page_contents, 1, usable_space).unwrap(); for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } } @@ -8360,7 +8360,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8370,9 +8370,9 @@ mod tests { for i in 0..total_cells { let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(i, i, page, record, &conn); - assert_eq!(page.cell_count(), i + 1); - let free = compute_free_space(page, usable_space); + let payload = add_record(i, i, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), i + 1); + let free = compute_free_space(page_contents, usable_space); total_size += payload.len() + 2; assert_eq!(free, 4096 - total_size - header_size); cells.push(Cell { pos: i, payload }); @@ -8382,7 +8382,7 @@ mod tests { let mut new_cells = Vec::new(); for cell in cells { if cell.pos % 2 == 1 { - drop_cell(page, cell.pos - removed, 
usable_space).unwrap(); + drop_cell(page_contents, cell.pos - removed, usable_space).unwrap(); removed += 1; } else { new_cells.push(cell); @@ -8390,11 +8390,11 @@ mod tests { } let cells = new_cells; for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } } @@ -8812,7 +8812,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8821,28 +8821,28 @@ mod tests { for i in 0..3 { let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(i, i, page, record, &conn); - assert_eq!(page.cell_count(), i + 1); - let free = compute_free_space(page, usable_space); + let payload = add_record(i, i, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), i + 1); + let free = compute_free_space(page_contents, usable_space); total_size += payload.len() + 2; assert_eq!(free, 4096 - total_size - header_size); cells.push(Cell { pos: i, payload }); } for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } cells.remove(1); - drop_cell(page, 1, usable_space).unwrap(); + drop_cell(page_contents, 1, usable_space).unwrap(); for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } } @@ -8853,7 +8853,7 @@ mod tests { let page = get_page(2); let page = page.get(); 
- let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8863,9 +8863,9 @@ mod tests { for i in 0..total_cells { let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(i, i, page, record, &conn); - assert_eq!(page.cell_count(), i + 1); - let free = compute_free_space(page, usable_space); + let payload = add_record(i, i, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), i + 1); + let free = compute_free_space(page_contents, usable_space); total_size += payload.len() + 2; assert_eq!(free, 4096 - total_size - header_size); cells.push(Cell { pos: i, payload }); @@ -8875,7 +8875,7 @@ mod tests { let mut new_cells = Vec::new(); for cell in cells { if cell.pos % 2 == 1 { - drop_cell(page, cell.pos - removed, usable_space).unwrap(); + drop_cell(page_contents, cell.pos - removed, usable_space).unwrap(); removed += 1; } else { new_cells.push(cell); @@ -8883,13 +8883,13 @@ mod tests { } let cells = new_cells; for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } } @@ -8900,7 +8900,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8915,8 +8915,8 @@ mod tests { match rng.next_u64() % 4 { 0 => { // allow appends with extra place to insert - let cell_idx = rng.next_u64() as usize % (page.cell_count() + 1); - let free = compute_free_space(page, usable_space); + let cell_idx = rng.next_u64() as usize % (page_contents.cell_count() + 1); + let free = 
compute_free_space(page_contents, usable_space); let regs = &[Register::Value(Value::Integer(i as i64))]; let record = ImmutableRecord::from_registers(regs, regs.len()); let mut payload: Vec = Vec::new(); @@ -8924,7 +8924,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page, + page.clone(), Some(i as i64), &mut payload, cell_idx, @@ -8941,34 +8941,34 @@ mod tests { // do not try to insert overflow pages because they require balancing continue; } - insert_into_cell(page, &payload, cell_idx, 4096).unwrap(); - assert!(page.overflow_cells.is_empty()); + insert_into_cell(page_contents, &payload, cell_idx, 4096).unwrap(); + assert!(page_contents.overflow_cells.is_empty()); total_size += payload.len() + 2; cells.insert(cell_idx, Cell { pos: i, payload }); } 1 => { - if page.cell_count() == 0 { + if page_contents.cell_count() == 0 { continue; } - let cell_idx = rng.next_u64() as usize % page.cell_count(); - let (_, len) = page.cell_get_raw_region(cell_idx, usable_space); - drop_cell(page, cell_idx, usable_space).unwrap(); + let cell_idx = rng.next_u64() as usize % page_contents.cell_count(); + let (_, len) = page_contents.cell_get_raw_region(cell_idx, usable_space); + drop_cell(page_contents, cell_idx, usable_space).unwrap(); total_size -= len + 2; cells.remove(cell_idx); } 2 => { - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); } 3 => { // check cells for (i, cell) in cells.iter().enumerate() { - ensure_cell(page, i, &cell.payload); + ensure_cell(page_contents, i, &cell.payload); } - assert_eq!(page.cell_count(), cells.len()); + assert_eq!(page_contents.cell_count(), cells.len()); } _ => unreachable!(), } - let free = compute_free_space(page, usable_space); + let free = compute_free_space(page_contents, usable_space); assert_eq!(free, 4096 - total_size - header_size); } } @@ -8982,7 +8982,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let 
page_contents = page.get_contents(); let header_size = 8; let mut total_size = 0; @@ -8997,8 +8997,8 @@ mod tests { match rng.next_u64() % 3 { 0 => { // allow appends with extra place to insert - let cell_idx = rng.next_u64() as usize % (page.cell_count() + 1); - let free = compute_free_space(page, usable_space); + let cell_idx = rng.next_u64() as usize % (page_contents.cell_count() + 1); + let free = compute_free_space(page_contents, usable_space); let regs = &[Register::Value(Value::Integer(i))]; let record = ImmutableRecord::from_registers(regs, regs.len()); let mut payload: Vec = Vec::new(); @@ -9006,7 +9006,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page, + page.clone(), Some(i), &mut payload, cell_idx, @@ -9023,8 +9023,8 @@ mod tests { // do not try to insert overflow pages because they require balancing continue; } - insert_into_cell(page, &payload, cell_idx, 4096).unwrap(); - assert!(page.overflow_cells.is_empty()); + insert_into_cell(page_contents, &payload, cell_idx, 4096).unwrap(); + assert!(page_contents.overflow_cells.is_empty()); total_size += payload.len() + 2; cells.push(Cell { pos: i as usize, @@ -9032,21 +9032,21 @@ mod tests { }); } 1 => { - if page.cell_count() == 0 { + if page_contents.cell_count() == 0 { continue; } - let cell_idx = rng.next_u64() as usize % page.cell_count(); - let (_, len) = page.cell_get_raw_region(cell_idx, usable_space); - drop_cell(page, cell_idx, usable_space).unwrap(); + let cell_idx = rng.next_u64() as usize % page_contents.cell_count(); + let (_, len) = page_contents.cell_get_raw_region(cell_idx, usable_space); + drop_cell(page_contents, cell_idx, usable_space).unwrap(); total_size -= len + 2; cells.remove(cell_idx); } 2 => { - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); } _ => unreachable!(), } - let free = compute_free_space(page, usable_space); + let free = compute_free_space(page_contents, usable_space); assert_eq!(free, 4096 - 
total_size - header_size); } } @@ -9158,14 +9158,14 @@ mod tests { let conn = db.connect().unwrap(); let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let header_size = 8; let usable_space = 4096; let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); - let free = compute_free_space(page, usable_space); + let payload = add_record(0, 0, page.clone(), record, &conn); + let free = compute_free_space(page_contents, usable_space); assert_eq!(free, 4096 - payload.len() - 2 - header_size); } @@ -9176,18 +9176,18 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let usable_space = 4096; let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); + let payload = add_record(0, 0, page.clone(), record, &conn); - assert_eq!(page.cell_count(), 1); - defragment_page(page, usable_space, 4).unwrap(); - assert_eq!(page.cell_count(), 1); - let (start, len) = page.cell_get_raw_region(0, usable_space); - let buf = page.as_ptr(); + assert_eq!(page_contents.cell_count(), 1); + defragment_page(page_contents, usable_space, 4).unwrap(); + assert_eq!(page_contents.cell_count(), 1); + let (start, len) = page_contents.cell_get_raw_region(0, usable_space); + let buf = page_contents.as_ptr(); assert_eq!(&payload, &buf[start..start + len]); } @@ -9198,7 +9198,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let usable_space = 4096; let regs = &[ @@ -9206,19 +9206,19 @@ mod tests { Register::Value(Value::Text(Text::new("aaaaaaaa"))), ]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 0, 
page, record, &conn); + let _ = add_record(0, 0, page.clone(), record, &conn); - assert_eq!(page.cell_count(), 1); - drop_cell(page, 0, usable_space).unwrap(); - assert_eq!(page.cell_count(), 0); + assert_eq!(page_contents.cell_count(), 1); + drop_cell(page_contents, 0, usable_space).unwrap(); + assert_eq!(page_contents.cell_count(), 0); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); - assert_eq!(page.cell_count(), 1); + let payload = add_record(0, 0, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), 1); - let (start, len) = page.cell_get_raw_region(0, usable_space); - let buf = page.as_ptr(); + let (start, len) = page_contents.cell_get_raw_region(0, usable_space); + let buf = page_contents.as_ptr(); assert_eq!(&payload, &buf[start..start + len]); } @@ -9229,7 +9229,7 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let usable_space = 4096; let regs = &[ @@ -9237,20 +9237,20 @@ mod tests { Register::Value(Value::Text(Text::new("aaaaaaaa"))), ]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 0, page, record, &conn); + let _ = add_record(0, 0, page.clone(), record, &conn); for _ in 0..100 { - assert_eq!(page.cell_count(), 1); - drop_cell(page, 0, usable_space).unwrap(); - assert_eq!(page.cell_count(), 0); + assert_eq!(page_contents.cell_count(), 1); + drop_cell(page_contents, 0, usable_space).unwrap(); + assert_eq!(page_contents.cell_count(), 0); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); - assert_eq!(page.cell_count(), 1); + let payload = add_record(0, 0, page.clone(), record, &conn); + assert_eq!(page_contents.cell_count(), 1); - let (start, len) = 
page.cell_get_raw_region(0, usable_space); - let buf = page.as_ptr(); + let (start, len) = page_contents.cell_get_raw_region(0, usable_space); + let buf = page_contents.as_ptr(); assert_eq!(&payload, &buf[start..start + len]); } } @@ -9262,23 +9262,23 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let usable_space = 4096; let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let payload = add_record(0, 0, page, record, &conn); + let payload = add_record(0, 0, page.clone(), record, &conn); let regs = &[Register::Value(Value::Integer(1))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(1, 1, page, record, &conn); + let _ = add_record(1, 1, page.clone(), record, &conn); let regs = &[Register::Value(Value::Integer(2))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(2, 2, page, record, &conn); + let _ = add_record(2, 2, page.clone(), record, &conn); - drop_cell(page, 1, usable_space).unwrap(); - drop_cell(page, 1, usable_space).unwrap(); + drop_cell(page_contents, 1, usable_space).unwrap(); + drop_cell(page_contents, 1, usable_space).unwrap(); - ensure_cell(page, 0, &payload); + ensure_cell(page_contents, 0, &payload); } #[test] @@ -9288,29 +9288,29 @@ mod tests { let page = get_page(2); let page = page.get(); - let page = page.get_contents(); + let page_contents = page.get_contents(); let usable_space = 4096; let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 0, page, record, &conn); + let _ = add_record(0, 0, page.clone(), record, &conn); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 0, page, record, &conn); - drop_cell(page, 0, usable_space).unwrap(); + let _ = 
add_record(0, 0, page.clone(), record, &conn); + drop_cell(page_contents, 0, usable_space).unwrap(); - defragment_page(page, usable_space, 4).unwrap(); + defragment_page(page_contents, usable_space, 4).unwrap(); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 1, page, record, &conn); + let _ = add_record(0, 1, page.clone(), record, &conn); - drop_cell(page, 0, usable_space).unwrap(); + drop_cell(page_contents, 0, usable_space).unwrap(); let regs = &[Register::Value(Value::Integer(0))]; let record = ImmutableRecord::from_registers(regs, regs.len()); - let _ = add_record(0, 1, page, record, &conn); + let _ = add_record(0, 1, page.clone(), record, &conn); } #[test] @@ -9334,21 +9334,21 @@ mod tests { let page = page.get(); defragment(page.get_contents()); defragment(page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); drop(0, page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); drop(0, page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); defragment(page.get_contents()); defragment(page.get_contents()); drop(0, page.get_contents()); defragment(page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); drop(0, page.get_contents()); - insert(0, page.get_contents()); - insert(1, page.get_contents()); - insert(1, page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); + insert(1, page.clone()); + insert(1, page.clone()); + insert(0, page.clone()); drop(3, page.get_contents()); drop(2, page.get_contents()); compute_free_space(page.get_contents(), usable_space); @@ -9379,7 +9379,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page.get().get_contents(), + page.get(), Some(0), &mut payload, 0, @@ -9393,11 +9393,11 @@ mod tests { ) .unwrap(); let page = page.get(); - insert(0, page.get_contents()); + insert(0, page.clone()); 
defragment(page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); defragment(page.get_contents()); - insert(0, page.get_contents()); + insert(0, page.clone()); drop(2, page.get_contents()); drop(0, page.get_contents()); let free = compute_free_space(page.get_contents(), usable_space); @@ -9465,7 +9465,7 @@ mod tests { run_until_done( || { fill_cell_payload( - page.get().get_contents(), + page.get(), Some(0), &mut payload, 0, @@ -9817,7 +9817,7 @@ mod tests { while compute_free_space(page.get_contents(), pager.usable_space()) >= size as usize + 10 { - insert_cell(i, size, page.get_contents(), pager.clone()); + insert_cell(i, size, page.clone(), pager.clone()); i += 1; size = (rng.next_u64() % 1024) as u16; } @@ -9868,15 +9868,16 @@ mod tests { } } - fn insert_cell(cell_idx: u64, size: u16, contents: &mut PageContent, pager: Rc) { + fn insert_cell(cell_idx: u64, size: u16, page: PageRef, pager: Rc) { let mut payload = Vec::new(); let regs = &[Register::Value(Value::Blob(vec![0; size as usize]))]; let record = ImmutableRecord::from_registers(regs, regs.len()); let mut fill_cell_payload_state = FillCellPayloadState::Start; + let contents = page.get_contents(); run_until_done( || { fill_cell_payload( - contents, + page.clone(), Some(cell_idx as i64), &mut payload, cell_idx as usize, From 53f9c0dc7a959136e6c269b1ca9b658794535691 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Sat, 23 Aug 2025 12:35:29 +0530 Subject: [PATCH 05/73] Add support for lord AEGIS, the fastest and the greatest --- Cargo.lock | 17 ++++ core/Cargo.toml | 1 + core/storage/encryption.rs | 202 +++++++++++++++++++++++++++++++++---- 3 files changed, 201 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 526ecbf51..81a4300c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,6 +27,16 @@ dependencies = [ "generic-array", ] +[[package]] +name = "aegis" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a2a1c2f54793fee13c334f70557d3bd6a029a9d453ebffd82ba571d139064da8" +dependencies = [ + "cc", + "softaes", +] + [[package]] name = "aes" version = "0.8.4" @@ -3440,6 +3450,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "softaes" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fef461faaeb36c340b6c887167a9054a034f6acfc50a014ead26a02b4356b3de" + [[package]] name = "sorted-vec" version = "0.8.6" @@ -3989,6 +4005,7 @@ dependencies = [ name = "turso_core" version = "0.1.4" dependencies = [ + "aegis", "aes", "aes-gcm", "antithesis_sdk", diff --git a/core/Cargo.toml b/core/Cargo.toml index e9f11969a..37c150524 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -77,6 +77,7 @@ bytemuck = "1.23.1" aes-gcm = { version = "0.10.3"} aes = { version = "0.8.4"} turso_parser = { workspace = true } +aegis = "0.9.0" [build-dependencies] chrono = { version = "0.4.38", default-features = false } diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index 97836d3d1..81d128f77 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -1,15 +1,13 @@ #![allow(unused_variables, dead_code)] use crate::{LimboError, Result}; +use aegis::aegis256::Aegis256; use aes_gcm::{ aead::{Aead, AeadCore, KeyInit, OsRng}, Aes256Gcm, Key, Nonce, }; use std::ops::Deref; -pub const ENCRYPTION_METADATA_SIZE: usize = 28; pub const ENCRYPTED_PAGE_SIZE: usize = 4096; -pub const ENCRYPTION_NONCE_SIZE: usize = 12; -pub const ENCRYPTION_TAG_SIZE: usize = 16; #[repr(transparent)] #[derive(Clone)] @@ -70,9 +68,65 @@ impl Drop for EncryptionKey { } } +// wrapper struct for AEGIS-256 cipher, because the crate we use is a bit low-level and we add +// some nice abstractions here +// note, the AEGIS has many variants and support for hardware acceleration. Here we just use the +// vanilla version, which is still order of maginitudes faster than AES-GCM in software. Hardware +// based compilation is left for future work. 
+#[derive(Clone)] +pub struct Aegis256Cipher { + key: EncryptionKey, +} + +impl Aegis256Cipher { + // AEGIS-256 supports both 16 and 32 byte tags, we use the 16 byte variant, it is faster + // and provides sufficient security for our use case. + const TAG_SIZE: usize = 16; + fn new(key: &EncryptionKey) -> Self { + Self { key: key.clone() } + } + + fn encrypt(&self, plaintext: &[u8], ad: &[u8]) -> Result<(Vec, [u8; 32])> { + let nonce = generate_secure_nonce(); + let (ciphertext, tag) = + Aegis256::<16>::new(self.key.as_bytes(), &nonce).encrypt(plaintext, ad); + let mut result = ciphertext; + result.extend_from_slice(&tag); + Ok((result, nonce)) + } + + fn decrypt(&self, ciphertext: &[u8], nonce: &[u8; 32], ad: &[u8]) -> Result> { + if ciphertext.len() < Self::TAG_SIZE { + return Err(LimboError::InternalError( + "Ciphertext too short for AEGIS-256".into(), + )); + } + let (ct, tag) = ciphertext.split_at(ciphertext.len() - Self::TAG_SIZE); + let tag_array: [u8; 16] = tag + .try_into() + .map_err(|_| LimboError::InternalError("Invalid tag size for AEGIS-256".into()))?; + + let plaintext = Aegis256::<16>::new(self.key.as_bytes(), nonce) + .decrypt(ct, &tag_array, ad) + .map_err(|_| { + LimboError::InternalError("AEGIS-256 decryption failed: invalid tag".into()) + })?; + Ok(plaintext) + } +} + +impl std::fmt::Debug for Aegis256Cipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Aegis256Cipher") + .field("key", &"") + .finish() + } +} + #[derive(Debug, Clone, Copy, PartialEq)] pub enum CipherMode { Aes256Gcm, + Aegis256, } impl CipherMode { @@ -81,33 +135,43 @@ impl CipherMode { pub fn required_key_size(&self) -> usize { match self { CipherMode::Aes256Gcm => 32, + CipherMode::Aegis256 => 32, } } - /// Returns the nonce size for this cipher mode. Though most AEAD ciphers use 12-byte nonces. + /// Returns the nonce size for this cipher mode. 
pub fn nonce_size(&self) -> usize { match self { - CipherMode::Aes256Gcm => ENCRYPTION_NONCE_SIZE, + CipherMode::Aes256Gcm => 12, + CipherMode::Aegis256 => 32, } } - /// Returns the authentication tag size for this cipher mode. All common AEAD ciphers use 16-byte tags. + /// Returns the authentication tag size for this cipher mode. pub fn tag_size(&self) -> usize { match self { - CipherMode::Aes256Gcm => ENCRYPTION_TAG_SIZE, + CipherMode::Aes256Gcm => 16, + CipherMode::Aegis256 => 16, } } + + /// Returns the total metadata size (nonce + tag) for this cipher mode. + pub fn metadata_size(&self) -> usize { + self.nonce_size() + self.tag_size() + } } #[derive(Clone)] pub enum Cipher { Aes256Gcm(Box), + Aegis256(Box), } impl std::fmt::Debug for Cipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Cipher::Aes256Gcm(_) => write!(f, "Cipher::Aes256Gcm"), + Cipher::Aegis256(_) => write!(f, "Cipher::Aegis256"), } } } @@ -119,8 +183,7 @@ pub struct EncryptionContext { } impl EncryptionContext { - pub fn new(key: &EncryptionKey) -> Result { - let cipher_mode = CipherMode::Aes256Gcm; + pub fn new(cipher_mode: CipherMode, key: &EncryptionKey) -> Result { let required_size = cipher_mode.required_key_size(); if key.as_slice().len() != required_size { return Err(crate::LimboError::InvalidArgument(format!( @@ -136,6 +199,7 @@ impl EncryptionContext { let cipher_key: &Key = key.as_ref().into(); Cipher::Aes256Gcm(Box::new(Aes256Gcm::new(cipher_key))) } + CipherMode::Aegis256 => Cipher::Aegis256(Box::new(Aegis256Cipher::new(key))), }; Ok(Self { cipher_mode, @@ -147,6 +211,11 @@ impl EncryptionContext { self.cipher_mode } + /// Returns the number of reserved bytes required at the end of each page for encryption metadata. 
+ pub fn required_reserved_bytes(&self) -> u8 { + self.cipher_mode.metadata_size() as u8 + } + #[cfg(feature = "encryption")] pub fn encrypt_page(&self, page: &[u8], page_id: usize) -> Result> { if page_id == 1 { @@ -159,21 +228,26 @@ impl EncryptionContext { ENCRYPTED_PAGE_SIZE, "Page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" ); - let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE..]; + + let metadata_size = self.cipher_mode.metadata_size(); + let reserved_bytes = &page[ENCRYPTED_PAGE_SIZE - metadata_size..]; let reserved_bytes_zeroed = reserved_bytes.iter().all(|&b| b == 0); assert!( reserved_bytes_zeroed, "last reserved bytes must be empty/zero, but found non-zero bytes" ); - let payload = &page[..ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE]; + + let payload = &page[..ENCRYPTED_PAGE_SIZE - metadata_size]; let (encrypted, nonce) = self.encrypt_raw(payload)?; + let nonce_size = self.cipher_mode.nonce_size(); assert_eq!( encrypted.len(), - ENCRYPTED_PAGE_SIZE - nonce.len(), + ENCRYPTED_PAGE_SIZE - nonce_size, "Encrypted page must be exactly {} bytes", - ENCRYPTED_PAGE_SIZE - nonce.len() + ENCRYPTED_PAGE_SIZE - nonce_size ); + let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE); result.extend_from_slice(&encrypted); result.extend_from_slice(&nonce); @@ -198,18 +272,21 @@ impl EncryptionContext { "Encrypted page data must be exactly {ENCRYPTED_PAGE_SIZE} bytes" ); - let nonce_start = encrypted_page.len() - ENCRYPTION_NONCE_SIZE; + let nonce_size = self.cipher_mode.nonce_size(); + let nonce_start = encrypted_page.len() - nonce_size; let payload = &encrypted_page[..nonce_start]; let nonce = &encrypted_page[nonce_start..]; let decrypted_data = self.decrypt_raw(payload, nonce)?; + let metadata_size = self.cipher_mode.metadata_size(); assert_eq!( decrypted_data.len(), - ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE, + ENCRYPTED_PAGE_SIZE - metadata_size, "Decrypted page data must be exactly {} bytes", - ENCRYPTED_PAGE_SIZE - 
ENCRYPTION_METADATA_SIZE + ENCRYPTED_PAGE_SIZE - metadata_size ); + let mut result = Vec::with_capacity(ENCRYPTED_PAGE_SIZE); result.extend_from_slice(&decrypted_data); result.resize(ENCRYPTED_PAGE_SIZE, 0); @@ -231,6 +308,11 @@ impl EncryptionContext { .map_err(|e| LimboError::InternalError(format!("Encryption failed: {e:?}")))?; Ok((ciphertext, nonce.to_vec())) } + Cipher::Aegis256(cipher) => { + let ad = b""; + let (ciphertext, nonce) = cipher.encrypt(plaintext, ad)?; + Ok((ciphertext, nonce.to_vec())) + } } } @@ -243,6 +325,16 @@ impl EncryptionContext { })?; Ok(plaintext) } + Cipher::Aegis256(cipher) => { + let nonce_array: [u8; 32] = nonce.try_into().map_err(|_| { + LimboError::InternalError(format!( + "Invalid nonce size for AEGIS-256: expected 32, got {}", + nonce.len() + )) + })?; + let ad = b""; + cipher.decrypt(ciphertext, &nonce_array, ad) + } } } @@ -261,6 +353,14 @@ impl EncryptionContext { } } +fn generate_secure_nonce() -> [u8; 32] { + // use OsRng directly to fill bytes, similar to how AeadCore does it + use aes_gcm::aead::rand_core::RngCore; + let mut nonce = [0u8; 32]; + OsRng.fill_bytes(&mut nonce); + nonce +} + #[cfg(test)] mod tests { use super::*; @@ -268,9 +368,11 @@ mod tests { #[test] #[cfg(feature = "encryption")] - fn test_encrypt_decrypt_round_trip() { + fn test_aes_encrypt_decrypt_round_trip() { let mut rng = rand::thread_rng(); - let data_size = ENCRYPTED_PAGE_SIZE - ENCRYPTION_METADATA_SIZE; + let cipher_mode = CipherMode::Aes256Gcm; + let metadata_size = cipher_mode.metadata_size(); + let data_size = ENCRYPTED_PAGE_SIZE - metadata_size; let page_data = { let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE]; @@ -281,7 +383,7 @@ mod tests { }; let key = EncryptionKey::from_string("alice and bob use encryption on database"); - let ctx = EncryptionContext::new(&key).unwrap(); + let ctx = EncryptionContext::new(CipherMode::Aes256Gcm, &key).unwrap(); let page_id = 42; let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap(); @@ -293,4 
+395,66 @@ mod tests { assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE); assert_eq!(decrypted, page_data); } + + #[test] + #[cfg(feature = "encryption")] + fn test_aegis256_cipher_wrapper() { + let key = EncryptionKey::from_string("alice and bob use AEGIS-256 here!"); + let cipher = Aegis256Cipher::new(&key); + + let plaintext = b"Hello, AEGIS-256!"; + let ad = b"additional data"; + + let (ciphertext, nonce) = cipher.encrypt(plaintext, ad).unwrap(); + assert_eq!(nonce.len(), 32); + assert_ne!(ciphertext[..plaintext.len()], plaintext[..]); + + let decrypted = cipher.decrypt(&ciphertext, &nonce, ad).unwrap(); + assert_eq!(decrypted, plaintext); + } + + #[test] + #[cfg(feature = "encryption")] + fn test_aegis256_raw_encryption() { + let key = EncryptionKey::from_string("alice and bob use AEGIS-256 here!"); + let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap(); + + let plaintext = b"Hello, AEGIS-256!"; + let (ciphertext, nonce) = ctx.encrypt_raw(plaintext).unwrap(); + + assert_eq!(nonce.len(), 32); // AEGIS-256 uses 32-byte nonces + assert_ne!(ciphertext[..plaintext.len()], plaintext[..]); + + let decrypted = ctx.decrypt_raw(&ciphertext, &nonce).unwrap(); + assert_eq!(decrypted, plaintext); + } + + #[test] + #[cfg(feature = "encryption")] + fn test_aegis256_encrypt_decrypt_round_trip() { + let mut rng = rand::thread_rng(); + let cipher_mode = CipherMode::Aegis256; + let metadata_size = cipher_mode.metadata_size(); + let data_size = ENCRYPTED_PAGE_SIZE - metadata_size; + + let page_data = { + let mut page = vec![0u8; ENCRYPTED_PAGE_SIZE]; + page.iter_mut() + .take(data_size) + .for_each(|byte| *byte = rng.gen()); + page + }; + + let key = EncryptionKey::from_string("alice and bob use AEGIS-256 for pages!"); + let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap(); + + let page_id = 42; + let encrypted = ctx.encrypt_page(&page_data, page_id).unwrap(); + assert_eq!(encrypted.len(), ENCRYPTED_PAGE_SIZE); + 
assert_ne!(&encrypted[..data_size], &page_data[..data_size]); + + let decrypted = ctx.decrypt_page(&encrypted, page_id).unwrap(); + assert_eq!(decrypted.len(), ENCRYPTED_PAGE_SIZE); + assert_eq!(decrypted, page_data); + } } From a4b9c33b8163937a75314847954f44aa0db2a01b Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Sat, 23 Aug 2025 12:44:09 +0530 Subject: [PATCH 06/73] Use the new API to init cipher --- core/storage/pager.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 4512b93c7..85f33775e 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -28,9 +28,7 @@ use super::btree::{btree_init_page, BTreePage}; use super::page_cache::{CacheError, CacheResizeResult, DumbLruPageCache, PageCacheKey}; use super::sqlite3_ondisk::begin_write_btree_page; use super::wal::CheckpointMode; -use crate::storage::encryption::{ - EncryptionKey, EncryptionContext, ENCRYPTION_METADATA_SIZE, -}; +use crate::storage::encryption::{CipherMode, EncryptionContext, EncryptionKey}; /// SQLite's default maximum page count const DEFAULT_MAX_PAGE_COUNT: u32 = 0xfffffffe; @@ -1726,8 +1724,8 @@ impl Pager { default_header.database_size = 1.into(); // if a key is set, then we will reserve space for encryption metadata - if self.encryption_ctx.borrow().is_some() { - default_header.reserved_space = ENCRYPTION_METADATA_SIZE as u8; + if let Some(ref ctx) = *self.encryption_ctx.borrow() { + default_header.reserved_space = ctx.required_reserved_bytes() } if let Some(size) = self.page_size.get() { @@ -2112,7 +2110,7 @@ impl Pager { } pub fn set_encryption_context(&self, key: &EncryptionKey) { - let encryption_ctx = EncryptionContext::new(key).unwrap(); + let encryption_ctx = EncryptionContext::new(CipherMode::Aegis256, key).unwrap(); self.encryption_ctx.replace(Some(encryption_ctx.clone())); let Some(wal) = self.wal.as_ref() else { return }; wal.borrow_mut().set_encryption_context(encryption_ctx) From 
77a4e96022f5e6919612487b70bc5606e283481d Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Sat, 23 Aug 2025 13:41:19 +0530 Subject: [PATCH 07/73] run encryption tests in CI --- .github/workflows/rust.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 12073def9..b354dd8e3 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -39,7 +39,9 @@ jobs: - name: Build run: cargo build --verbose - name: Test Encryption - run: cargo test --features encryption --color=always --test integration_tests query_processing::encryption + run: | + cargo test --features encryption --color=always --test integration_tests query_processing::encryption + cargo test --features encryption --color=always --lib storage::encryption - name: Test env: RUST_LOG: ${{ runner.debug && 'turso_core::storage=trace' || '' }} From 011f8781589301e99f835d4b88e8ca1a6e0a25d8 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Sun, 24 Aug 2025 16:21:06 +0530 Subject: [PATCH 08/73] make clippy bro happy --- core/lib.rs | 2 +- core/storage/database.rs | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index 5ba3a4faa..e8842527c 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -75,7 +75,7 @@ use std::{ }; #[cfg(feature = "fs")] use storage::database::DatabaseFile; -pub use storage::encryption::{EncryptionKey, EncryptionContext}; +pub use storage::encryption::{EncryptionContext, EncryptionKey}; use storage::page_cache::DumbLruPageCache; use storage::pager::{AtomicDbState, DbState}; use storage::sqlite3_ondisk::PageSize; diff --git a/core/storage/database.rs b/core/storage/database.rs index d608558fc..b13962d30 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -184,11 +184,7 @@ impl DatabaseFile { } } -fn encrypt_buffer( - page_idx: usize, - buffer: Arc, - ctx: &EncryptionContext, -) -> Arc { +fn encrypt_buffer(page_idx: usize, 
buffer: Arc, ctx: &EncryptionContext) -> Arc { let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap(); Arc::new(Buffer::new(encrypted_data.to_vec())) } From 2c6fa76437f3dcc6a57b4e70960b244c9acd2959 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sun, 24 Aug 2025 14:13:20 +0300 Subject: [PATCH 09/73] cargo fmt --- core/lib.rs | 2 +- core/storage/database.rs | 6 +----- core/storage/pager.rs | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index 5ba3a4faa..e8842527c 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -75,7 +75,7 @@ use std::{ }; #[cfg(feature = "fs")] use storage::database::DatabaseFile; -pub use storage::encryption::{EncryptionKey, EncryptionContext}; +pub use storage::encryption::{EncryptionContext, EncryptionKey}; use storage::page_cache::DumbLruPageCache; use storage::pager::{AtomicDbState, DbState}; use storage::sqlite3_ondisk::PageSize; diff --git a/core/storage/database.rs b/core/storage/database.rs index d608558fc..b13962d30 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -184,11 +184,7 @@ impl DatabaseFile { } } -fn encrypt_buffer( - page_idx: usize, - buffer: Arc, - ctx: &EncryptionContext, -) -> Arc { +fn encrypt_buffer(page_idx: usize, buffer: Arc, ctx: &EncryptionContext) -> Arc { let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap(); Arc::new(Buffer::new(encrypted_data.to_vec())) } diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 4512b93c7..b4dd8f2f0 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -28,9 +28,7 @@ use super::btree::{btree_init_page, BTreePage}; use super::page_cache::{CacheError, CacheResizeResult, DumbLruPageCache, PageCacheKey}; use super::sqlite3_ondisk::begin_write_btree_page; use super::wal::CheckpointMode; -use crate::storage::encryption::{ - EncryptionKey, EncryptionContext, ENCRYPTION_METADATA_SIZE, -}; +use crate::storage::encryption::{EncryptionContext, EncryptionKey, 
ENCRYPTION_METADATA_SIZE}; /// SQLite's default maximum page count const DEFAULT_MAX_PAGE_COUNT: u32 = 0xfffffffe; From ea2192c332a1ae1e4814ac111e203b6422aacf3d Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sun, 24 Aug 2025 14:05:48 +0300 Subject: [PATCH 10/73] sqlite3: Implement sqlite3_get_autocommit() --- sqlite3/src/lib.rs | 13 +++++- sqlite3/tests/compat/mod.rs | 80 +++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 32a2976e3..2fbf464c6 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -378,8 +378,17 @@ pub unsafe extern "C" fn sqlite3_deserialize( } #[no_mangle] -pub unsafe extern "C" fn sqlite3_get_autocommit(_db: *mut sqlite3) -> ffi::c_int { - stub!(); +pub unsafe extern "C" fn sqlite3_get_autocommit(db: *mut sqlite3) -> ffi::c_int { + if db.is_null() { + return 1; + } + let db: &mut sqlite3 = &mut *db; + let inner = db.inner.lock().unwrap(); + if inner.conn.get_auto_commit() { + 1 + } else { + 0 + } } #[no_mangle] diff --git a/sqlite3/tests/compat/mod.rs b/sqlite3/tests/compat/mod.rs index 2192a89f9..df5fca2b0 100644 --- a/sqlite3/tests/compat/mod.rs +++ b/sqlite3/tests/compat/mod.rs @@ -71,6 +71,7 @@ extern "C" { fn sqlite3_column_blob(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_void; fn sqlite3_column_type(stmt: *mut sqlite3_stmt, idx: i32) -> i32; fn sqlite3_column_decltype(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; + fn sqlite3_get_autocommit(db: *mut sqlite3) -> i32; } const SQLITE_OK: i32 = 0; @@ -986,6 +987,85 @@ mod tests { } } + #[test] + fn test_get_autocommit() { + unsafe { + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Should be in autocommit mode by default + assert_eq!(sqlite3_get_autocommit(db), 1); + + 
// Begin a transaction + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"BEGIN".as_ptr(), -1, &mut stmt, ptr::null_mut()), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + // Should NOT be in autocommit mode during transaction + assert_eq!(sqlite3_get_autocommit(db), 0); + + // Create a table within the transaction + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2( + db, + c"CREATE TABLE test (id INTEGER PRIMARY KEY)".as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + // Still not in autocommit mode + assert_eq!(sqlite3_get_autocommit(db), 0); + + // Commit the transaction + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"COMMIT".as_ptr(), -1, &mut stmt, ptr::null_mut()), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + // Should be back in autocommit mode after commit + assert_eq!(sqlite3_get_autocommit(db), 1); + + // Test with ROLLBACK + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"BEGIN".as_ptr(), -1, &mut stmt, ptr::null_mut()), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + assert_eq!(sqlite3_get_autocommit(db), 0); + + let mut stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"ROLLBACK".as_ptr(), -1, &mut stmt, ptr::null_mut()), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + // Should be back in autocommit mode after rollback + assert_eq!(sqlite3_get_autocommit(db), 1); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } + #[test] fn test_wal_checkpoint() { let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); From 
9d2f26bb045065f25b8a93f9fa822f1600c4e36f Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sun, 24 Aug 2025 16:58:06 +0300 Subject: [PATCH 11/73] sqlite3: Implement sqlite3_clear_bindings() --- core/lib.rs | 4 +++ core/vdbe/mod.rs | 4 +++ sqlite3/src/lib.rs | 12 +++++++ sqlite3/tests/compat/mod.rs | 65 +++++++++++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+) diff --git a/core/lib.rs b/core/lib.rs index e8842527c..843c40bc3 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -2170,6 +2170,10 @@ impl Statement { self.state.bind_at(index, value); } + pub fn clear_bindings(&mut self) { + self.state.clear_bindings(); + } + pub fn reset(&mut self) { self.state.reset(); } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index b99c17cd5..5a9cbe646 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -335,6 +335,10 @@ impl ProgramState { self.parameters.insert(index, value); } + pub fn clear_bindings(&mut self) { + self.parameters.clear(); + } + pub fn get_parameter(&self, index: NonZero) -> Value { self.parameters.get(&index).cloned().unwrap_or(Value::Null) } diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 32a2976e3..dc7d3491e 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -702,6 +702,18 @@ pub unsafe extern "C" fn sqlite3_bind_blob( SQLITE_OK } +#[no_mangle] +pub unsafe extern "C" fn sqlite3_clear_bindings(stmt: *mut sqlite3_stmt) -> ffi::c_int { + if stmt.is_null() { + return SQLITE_MISUSE; + } + + let stmt_ref = &mut *stmt; + stmt_ref.stmt.clear_bindings(); + + SQLITE_OK +} + #[no_mangle] pub unsafe extern "C" fn sqlite3_column_type( stmt: *mut sqlite3_stmt, diff --git a/sqlite3/tests/compat/mod.rs b/sqlite3/tests/compat/mod.rs index 2192a89f9..96f12a991 100644 --- a/sqlite3/tests/compat/mod.rs +++ b/sqlite3/tests/compat/mod.rs @@ -27,6 +27,7 @@ extern "C" { tail: *mut *const libc::c_char, ) -> i32; fn sqlite3_step(stmt: *mut sqlite3_stmt) -> i32; + fn sqlite3_reset(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_finalize(stmt: *mut 
sqlite3_stmt) -> i32; fn sqlite3_wal_checkpoint(db: *mut sqlite3, db_name: *const libc::c_char) -> i32; fn sqlite3_wal_checkpoint_v2( @@ -49,6 +50,7 @@ extern "C" { fn sqlite3_bind_int(stmt: *mut sqlite3_stmt, idx: i32, val: i64) -> i32; fn sqlite3_bind_parameter_count(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_bind_parameter_name(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; + fn sqlite3_clear_bindings(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_column_name(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; fn sqlite3_last_insert_rowid(db: *mut sqlite3) -> i32; fn sqlite3_column_count(stmt: *mut sqlite3_stmt) -> i32; @@ -1095,4 +1097,67 @@ mod tests { } } } + + #[test] + fn test_sqlite3_clear_bindings() { + unsafe { + let mut db: *mut sqlite3 = ptr::null_mut(); + let mut stmt: *mut sqlite3_stmt = ptr::null_mut(); + + assert_eq!(sqlite3_open(c":memory:".as_ptr(), &mut db), SQLITE_OK); + + assert_eq!( + sqlite3_prepare_v2( + db, + c"CREATE TABLE person (id INTEGER, name TEXT, age INTEGER)".as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + assert_eq!( + sqlite3_prepare_v2( + db, + c"INSERT INTO person (id, name, age) VALUES (1, 'John', 25), (2, 'Jane', 30)" + .as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT * FROM person WHERE id = ? 
AND age > ?".as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + + // Bind parameters - should find John (id=1, age=25 > 20) + assert_eq!(sqlite3_bind_int(stmt, 1, 1), SQLITE_OK); + assert_eq!(sqlite3_bind_int(stmt, 2, 20), SQLITE_OK); + assert_eq!(sqlite3_step(stmt), SQLITE_ROW); + assert_eq!(sqlite3_column_int(stmt, 0), 1); + assert_eq!(sqlite3_column_int(stmt, 2), 25); + + // Reset and clear bindings, query should return no rows + assert_eq!(sqlite3_reset(stmt), SQLITE_OK); + assert_eq!(sqlite3_clear_bindings(stmt), SQLITE_OK); + assert_eq!(sqlite3_step(stmt), SQLITE_DONE); + + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } } From c428ff06b2d9b5d808115f9e66fe2b31da875164 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sun, 24 Aug 2025 20:10:31 +0300 Subject: [PATCH 12/73] sqlite3: Implement sqlite3_bind_parameter_index() --- core/lib.rs | 4 ++++ sqlite3/include/sqlite3.h | 2 ++ sqlite3/src/lib.rs | 22 +++++++++++++++++++++ sqlite3/tests/compat/mod.rs | 39 +++++++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+) diff --git a/core/lib.rs b/core/lib.rs index 843c40bc3..e8e629887 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -2166,6 +2166,10 @@ impl Statement { self.program.parameters.count() } + pub fn parameter_index(&self, name: &str) -> Option> { + self.program.parameters.index(name) + } + pub fn bind_at(&mut self, index: NonZero, value: Value) { self.state.bind_at(index, value); } diff --git a/sqlite3/include/sqlite3.h b/sqlite3/include/sqlite3.h index 21695d4f1..68a8944e5 100644 --- a/sqlite3/include/sqlite3.h +++ b/sqlite3/include/sqlite3.h @@ -153,6 +153,8 @@ int sqlite3_bind_parameter_count(sqlite3_stmt *_stmt); const char *sqlite3_bind_parameter_name(sqlite3_stmt *_stmt, int _idx); +int sqlite3_bind_parameter_index(sqlite3_stmt *_stmt, const char *_name); + int sqlite3_bind_null(sqlite3_stmt *_stmt, int _idx); int sqlite3_bind_int64(sqlite3_stmt *_stmt, int _idx, 
int64_t _val); diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 89facaec3..b7bd6529a 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -538,6 +538,28 @@ pub unsafe extern "C" fn sqlite3_bind_parameter_name( } } +#[no_mangle] +pub unsafe extern "C" fn sqlite3_bind_parameter_index( + stmt: *mut sqlite3_stmt, + name: *const ffi::c_char, +) -> ffi::c_int { + if stmt.is_null() || name.is_null() { + return 0; + } + + let stmt = &*stmt; + let name_str = match CStr::from_ptr(name).to_str() { + Ok(s) => s, + Err(_) => return 0, + }; + + if let Some(index) = stmt.stmt.parameter_index(name_str) { + index.get() as ffi::c_int + } else { + 0 + } +} + #[no_mangle] pub unsafe extern "C" fn sqlite3_bind_null(stmt: *mut sqlite3_stmt, idx: ffi::c_int) -> ffi::c_int { if stmt.is_null() { diff --git a/sqlite3/tests/compat/mod.rs b/sqlite3/tests/compat/mod.rs index c1a92f7a8..a40ae7538 100644 --- a/sqlite3/tests/compat/mod.rs +++ b/sqlite3/tests/compat/mod.rs @@ -50,6 +50,7 @@ extern "C" { fn sqlite3_bind_int(stmt: *mut sqlite3_stmt, idx: i32, val: i64) -> i32; fn sqlite3_bind_parameter_count(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_bind_parameter_name(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; + fn sqlite3_bind_parameter_index(stmt: *mut sqlite3_stmt, name: *const libc::c_char) -> i32; fn sqlite3_clear_bindings(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_column_name(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; fn sqlite3_last_insert_rowid(db: *mut sqlite3) -> i32; @@ -1240,4 +1241,42 @@ mod tests { assert_eq!(sqlite3_close(db), SQLITE_OK); } } + + #[test] + fn test_sqlite3_bind_parameter_index() { + const SQLITE_OK: i32 = 0; + + unsafe { + let mut db: *mut sqlite3 = ptr::null_mut(); + let mut stmt: *mut sqlite3_stmt = ptr::null_mut(); + + assert_eq!(sqlite3_open(c":memory:".as_ptr(), &mut db), SQLITE_OK); + + assert_eq!( + sqlite3_prepare_v2( + db, + c"SELECT * FROM sqlite_master WHERE name = :table_name AND type = :object_type" + 
.as_ptr(), + -1, + &mut stmt, + ptr::null_mut() + ), + SQLITE_OK + ); + + let index1 = sqlite3_bind_parameter_index(stmt, c":table_name".as_ptr()); + assert_eq!(index1, 1); + + let index2 = sqlite3_bind_parameter_index(stmt, c":object_type".as_ptr()); + assert_eq!(index2, 2); + + let index3 = sqlite3_bind_parameter_index(stmt, c":nonexistent".as_ptr()); + assert_eq!(index3, 0); + + let index4 = sqlite3_bind_parameter_index(stmt, ptr::null()); + assert_eq!(index4, 0); + + assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); + } + } } From 37cebb066990cc46ffa8c05d932cd6373478c526 Mon Sep 17 00:00:00 2001 From: bit-aloo Date: Sun, 24 Aug 2025 22:59:47 +0530 Subject: [PATCH 13/73] fix(clippy): remove duplicate arc_with_non_send_sync attribute in wal.rs --- core/storage/wal.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 3b07d80f8..825c21bea 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -1,4 +1,3 @@ -#![allow(clippy::arc_with_non_send_sync)] #![allow(clippy::not_unsafe_ptr_arg_deref)] use std::array; From 543025f57a8398a37ad4212366d4e1d57cdd02b6 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Mon, 25 Aug 2025 01:32:41 +0530 Subject: [PATCH 14/73] rename encryption PRAGMA key to `hexkey` --- core/pragma.rs | 2 +- parser/src/ast.rs | 7 +++---- tests/integration/query_processing/encryption.rs | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/core/pragma.rs b/core/pragma.rs index f1a77fa66..0ae48b97a 100644 --- a/core/pragma.rs +++ b/core/pragma.rs @@ -109,7 +109,7 @@ pub fn pragma_for(pragma: &PragmaName) -> Pragma { FreelistCount => Pragma::new(PragmaFlags::Result0, &["freelist_count"]), EncryptionKey => Pragma::new( PragmaFlags::Result0 | PragmaFlags::SchemaReq | PragmaFlags::NoColumns1, - &["key"], + &["hexkey"], ), } } diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 440685014..0c9de857c 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -1221,10 +1221,9 @@ pub 
enum PragmaName { IntegrityCheck, /// `journal_mode` pragma JournalMode, - /// encryption key for encrypted databases. This is just called `key` because most - /// extensions use this name instead of `encryption_key`. - #[strum(serialize = "key")] - #[cfg_attr(feature = "serde", serde(rename = "key"))] + /// encryption key for encrypted databases, specified as hexadecimal string. + #[strum(serialize = "hexkey")] + #[cfg_attr(feature = "serde", serde(rename = "hexkey"))] EncryptionKey, /// Noop as per SQLite docs LegacyFileFormat, diff --git a/tests/integration/query_processing/encryption.rs b/tests/integration/query_processing/encryption.rs index c286bde7c..92b43cd20 100644 --- a/tests/integration/query_processing/encryption.rs +++ b/tests/integration/query_processing/encryption.rs @@ -16,7 +16,7 @@ fn test_per_page_encryption() -> anyhow::Result<()> { run_query( &tmp_db, &conn, - "PRAGMA key = 'super secret key for encryption';", + "PRAGMA hexkey = 'super secret key for encryption';", )?; run_query( &tmp_db, @@ -58,7 +58,7 @@ fn test_per_page_encryption() -> anyhow::Result<()> { run_query( &existing_db, &conn, - "PRAGMA key = 'super secret key for encryption';", + "PRAGMA hexkey = 'super secret key for encryption';", )?; run_query_on_row(&existing_db, &conn, "SELECT * FROM test", |row: &Row| { assert_eq!(row.get::(0).unwrap(), 1); From 0308374d3a977728e71cd91ca9ac3246e2057176 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Mon, 25 Aug 2025 01:36:05 +0530 Subject: [PATCH 15/73] Use proper hexadecimal key for encryption Added `from_hex_string` which gets us `EncryptionKey` from a hex string. 
Now we can use securely generated keys, like from openssl $ openssl rand -hex 32 --- core/storage/encryption.rs | 13 +++++++++++++ core/translate/pragma.rs | 2 +- tests/integration/query_processing/encryption.rs | 4 ++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index 81d128f77..e5e108a6b 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -26,6 +26,19 @@ impl EncryptionKey { Self(key) } + pub fn from_hex_string(s: &str) -> Result { + let hex_str = s.trim(); + let bytes = hex::decode(hex_str) + .map_err(|e| LimboError::InvalidArgument(format!("Invalid hex string: {}", e)))?; + let key: [u8; 32] = bytes.try_into().map_err(|v: Vec| { + LimboError::InvalidArgument(format!( + "Hex string must decode to exactly 32 bytes, got {}", + v.len() + )) + })?; + Ok(Self(key)) + } + pub fn as_bytes(&self) -> &[u8; 32] { &self.0 } diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index 3ef8b84a5..b53c7af5b 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -314,7 +314,7 @@ fn update_pragma( ), PragmaName::EncryptionKey => { let value = parse_string(&value)?; - let key = EncryptionKey::from_string(&value); + let key = EncryptionKey::from_hex_string(&value)?; connection.set_encryption_key(key); Ok((program, TransactionMode::None)) } diff --git a/tests/integration/query_processing/encryption.rs b/tests/integration/query_processing/encryption.rs index 92b43cd20..82bef1ef1 100644 --- a/tests/integration/query_processing/encryption.rs +++ b/tests/integration/query_processing/encryption.rs @@ -16,7 +16,7 @@ fn test_per_page_encryption() -> anyhow::Result<()> { run_query( &tmp_db, &conn, - "PRAGMA hexkey = 'super secret key for encryption';", + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';", )?; run_query( &tmp_db, @@ -58,7 +58,7 @@ fn test_per_page_encryption() -> anyhow::Result<()> { run_query( &existing_db, 
&conn, - "PRAGMA hexkey = 'super secret key for encryption';", + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';", )?; run_query_on_row(&existing_db, &conn, "SELECT * FROM test", |row: &Row| { assert_eq!(row.get::(0).unwrap(), 1); From 279bcd0869557be6f68f43053de40a25e8ce3f8d Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Mon, 25 Aug 2025 01:46:44 +0530 Subject: [PATCH 16/73] Remove unsecure `EncryptionKey::from_string` --- core/storage/encryption.rs | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index e5e108a6b..e8b7f1181 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -18,14 +18,6 @@ impl EncryptionKey { Self(key) } - pub fn from_string(s: &str) -> Self { - let mut key = [0u8; 32]; - let bytes = s.as_bytes(); - let len = bytes.len().min(32); - key[..len].copy_from_slice(&bytes[..len]); - Self(key) - } - pub fn from_hex_string(s: &str) -> Result { let hex_str = s.trim(); let bytes = hex::decode(hex_str) @@ -379,6 +371,13 @@ mod tests { use super::*; use rand::Rng; + fn generate_random_hex_key() -> String { + let mut rng = rand::thread_rng(); + let mut bytes = [0u8; 32]; + rng.fill(&mut bytes); + hex::encode(bytes) + } + #[test] #[cfg(feature = "encryption")] fn test_aes_encrypt_decrypt_round_trip() { @@ -395,7 +394,7 @@ mod tests { page }; - let key = EncryptionKey::from_string("alice and bob use encryption on database"); + let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap(); let ctx = EncryptionContext::new(CipherMode::Aes256Gcm, &key).unwrap(); let page_id = 42; @@ -412,7 +411,7 @@ mod tests { #[test] #[cfg(feature = "encryption")] fn test_aegis256_cipher_wrapper() { - let key = EncryptionKey::from_string("alice and bob use AEGIS-256 here!"); + let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap(); let cipher = Aegis256Cipher::new(&key); 
let plaintext = b"Hello, AEGIS-256!"; @@ -429,7 +428,7 @@ mod tests { #[test] #[cfg(feature = "encryption")] fn test_aegis256_raw_encryption() { - let key = EncryptionKey::from_string("alice and bob use AEGIS-256 here!"); + let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap(); let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap(); let plaintext = b"Hello, AEGIS-256!"; @@ -458,7 +457,7 @@ mod tests { page }; - let key = EncryptionKey::from_string("alice and bob use AEGIS-256 for pages!"); + let key = EncryptionKey::from_hex_string(&generate_random_hex_key()).unwrap(); let ctx = EncryptionContext::new(CipherMode::Aegis256, &key).unwrap(); let page_id = 42; From 370da9fa590e73759d7ac5dc6464b59b92ec6cfa Mon Sep 17 00:00:00 2001 From: Alex Miller Date: Sun, 24 Aug 2025 13:28:42 -0700 Subject: [PATCH 17/73] ANALYZE creates sqlite_stat1 if it doesn't exist This change replaces a bail_parse_error!() when sqlite_stat1 doesn't exist with the appropriate codegen to create the table, and handle both cases of the table existing or not existing. 
SQLite's codegen looks like: sqlite> create table stat_test(a,b,c); sqlite> explain analyze stat_test; addr opcode p1 p2 p3 p4 p5 comment ---- ------------- ---- ---- ---- ------------- -- ------------- 0 Init 0 40 0 0 Start at 40 1 ReadCookie 0 3 2 0 2 If 3 5 0 0 3 SetCookie 0 2 4 0 4 SetCookie 0 5 1 0 5 CreateBtree 0 2 1 0 r[2]=root iDb=0 flags=1 6 OpenWrite 0 1 0 5 0 root=1 iDb=0 7 NewRowid 0 1 0 0 r[1]=rowid 8 Blob 6 3 0 0 r[3]= (len=6) 9 Insert 0 3 1 8 intkey=r[1] data=r[3] 10 Close 0 0 0 0 11 Close 0 0 0 0 12 Null 0 4 5 0 r[4..5]=NULL 13 Noop 4 0 4 0 14 OpenWrite 3 1 0 5 0 root=1 iDb=0; sqlite_master 15 SeekRowid 3 17 1 0 intkey=r[1] 16 Rowid 3 5 0 0 r[5]= rowid of 3 17 IsNull 5 26 0 0 if r[5]==NULL goto 26 18 String8 0 6 0 table 0 r[6]='table' 19 String8 0 7 0 sqlite_stat1 0 r[7]='sqlite_stat1' 20 String8 0 8 0 sqlite_stat1 0 r[8]='sqlite_stat1' 21 Copy 2 9 0 0 r[9]=r[2] 22 String8 0 10 0 CREATE TABLE sqlite_stat1(tbl,idx,stat) 0 r[10]='CREATE TABLE sqlite_stat1(tbl,idx,stat)' 23 MakeRecord 6 5 4 BBBDB 0 r[4]=mkrec(r[6..10]) 24 Delete 3 68 5 0 25 Insert 3 4 5 0 intkey=r[5] data=r[4] 26 SetCookie 0 1 2 0 27 ParseSchema 0 0 0 tbl_name='sqlite_stat1' AND type!='trigger' 0 28 OpenWrite 0 2 0 3 16 root=2 iDb=0; sqlite_stat1 29 OpenRead 5 2 0 3 0 root=2 iDb=0; stat_test 30 String8 0 18 0 stat_test 0 r[18]='stat_test'; stat_test 31 Count 5 20 0 0 r[20]=count() 32 IfNot 20 37 0 0 33 Null 0 19 0 0 r[19]=NULL 34 MakeRecord 18 3 16 BBB 0 r[16]=mkrec(r[18..20]) 35 NewRowid 0 12 0 0 r[12]=rowid 36 Insert 0 16 12 8 intkey=r[12] data=r[16] 37 LoadAnalysis 0 0 0 0 38 Expire 0 0 0 0 39 Halt 0 0 0 0 40 Transaction 0 1 1 0 1 usesStmtJournal=1 41 Goto 0 1 0 0 And now Turso's looks like: turso> create table stat_test(a,b,c); turso> explain analyze stat_test; addr opcode p1 p2 p3 p4 p5 comment ---- ----------------- ---- ---- ---- ------------- -- ------- 0 Init 0 23 0 0 Start at 23 1 Null 0 1 0 0 r[1]=NULL 2 CreateBtree 0 2 1 0 r[2]=root iDb=0 flags=1 3 OpenWrite 0 1 0 0 root=1; 
iDb=0 4 NewRowid 0 3 0 0 r[3]=rowid 5 String8 0 4 0 table 0 r[4]='table' 6 String8 0 5 0 sqlite_stat1 0 r[5]='sqlite_stat1' 7 String8 0 6 0 sqlite_stat1 0 r[6]='sqlite_stat1' 8 Copy 2 7 0 0 r[7]=r[2] 9 String8 0 8 0 CREATE TABLE sqlite_stat1(tbl,idx,stat) 0 r[8]='CREATE TABLE sqlite_stat1(tbl,idx,stat)' 10 MakeRecord 4 5 9 0 r[9]=mkrec(r[4..8]) 11 Insert 0 9 3 sqlite_stat1 0 intkey=r[3] data=r[9] 12 ParseSchema 0 0 0 tbl_name = 'sqlite_stat1' AND type != 'trigger' 0 tbl_name = 'sqlite_stat1' AND type != 'trigger' 13 OpenWrite 1 2 0 0 root=2; iDb=0 14 OpenRead 2 2 0 0 =stat_test, root=2, iDb=0 15 String8 0 12 0 stat_test 0 r[12]='stat_test' 16 Count 2 14 0 0 17 IfNot 14 22 0 0 if !r[14] goto 22 18 Null 0 13 0 0 r[13]=NULL 19 MakeRecord 12 3 11 0 r[11]=mkrec(r[12..14]) 20 NewRowid 1 10 0 0 r[10]=rowid 21 Insert 1 11 10 sqlite_stat1 0 intkey=r[10] data=r[11] 22 Halt 0 0 0 0 23 Goto 0 1 0 0 The notable difference in size is following the same codegen difference in CREATE TABLE, where sqlite's odd dance of adding a placeholder entry which is immediately replaced is instead done in tursodb as just inserting the correct row in the first place. Aside from lines 6-13 of sqlite's vdbe being missing, there's still the lack of LoadAnalysis, Expire, and Cookie management. 
--- core/translate/analyze.rs | 86 +++++++++++++++++++++++++++++++++++---- core/translate/mod.rs | 2 +- testing/analyze.test | 18 ++------ 3 files changed, 82 insertions(+), 24 deletions(-) diff --git a/core/translate/analyze.rs b/core/translate/analyze.rs index 4b72b1457..0d8f8de4e 100644 --- a/core/translate/analyze.rs +++ b/core/translate/analyze.rs @@ -1,19 +1,27 @@ +use std::sync::Arc; + use turso_parser::ast; use crate::{ bail_parse_error, - schema::Schema, + schema::{BTreeTable, Schema}, + storage::pager::CreateBTreeFlags, + translate::{ + emitter::Resolver, + schema::{emit_schema_entry, SchemaEntryType, SQLITE_TABLEID}, + }, util::normalize_ident, vdbe::{ builder::{CursorType, ProgramBuilder}, - insn::{Insn, RegisterOrLiteral::*}, + insn::{Insn, RegisterOrLiteral}, }, - Result, + Result, SymbolTable, }; pub fn translate_analyze( target_opt: Option, schema: &Schema, + syms: &SymbolTable, mut program: ProgramBuilder, ) -> Result { let Some(target) = target_opt else { @@ -34,7 +42,15 @@ pub fn translate_analyze( dest_end: None, }); + // After preparing/creating sqlite_stat1, we need to OpenWrite it, and how we acquire + // the necessary BTreeTable for cursor creation and root page for the instruction changes + // depending on which path we take. + let sqlite_stat1_btreetable: Arc; + let sqlite_stat1_source: RegisterOrLiteral<_>; + if let Some(sqlite_stat1) = schema.get_btree_table("sqlite_stat1") { + sqlite_stat1_btreetable = sqlite_stat1.clone(); + sqlite_stat1_source = RegisterOrLiteral::Literal(sqlite_stat1.root_page); // sqlite_stat1 already exists, so we need to remove the row // corresponding to the stats for the table which we're about to // ANALYZE. 
SQLite implements this as a full table scan over @@ -43,7 +59,7 @@ pub fn translate_analyze( let cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_stat1.clone())); program.emit_insn(Insn::OpenWrite { cursor_id, - root_page: Literal(sqlite_stat1.root_page), + root_page: RegisterOrLiteral::Literal(sqlite_stat1.root_page), db: 0, }); let after_loop = program.allocate_label(); @@ -89,7 +105,61 @@ pub fn translate_analyze( }); program.preassign_label_to_next_insn(after_loop); } else { - bail_parse_error!("ANALYZE without an existing sqlite_stat1 is not supported"); + // FIXME: Emit ReadCookie 0 3 2 + // FIXME: Emit If 3 +2 0 + // FIXME: Emit SetCookie 0 2 4 + // FIXME: Emit SetCookie 0 5 1 + + // See the large comment in schema.rs:translate_create_table about + // deviating from SQLite codegen, as the same deviation is being done + // here. + + // TODO: this code half-copies translate_create_table, because there's + // no way to get the table_root_reg back out, and it's needed for later + // codegen to open the table we just created. It's worth a future + // refactoring to remove the duplication one the rest of ANALYZE is + // implemented. + let table_root_reg = program.alloc_register(); + program.emit_insn(Insn::CreateBtree { + db: 0, + root: table_root_reg, + flags: CreateBTreeFlags::new_table(), + }); + let sql = "CREATE TABLE sqlite_stat1(tbl,idx,stat)"; + // The root_page==0 is false, but we don't rely on it, and there's no + // way to initialize it with a correct value. 
+ sqlite_stat1_btreetable = Arc::new(BTreeTable::from_sql(sql, 0)?); + sqlite_stat1_source = RegisterOrLiteral::Register(table_root_reg); + + let table = schema.get_btree_table(SQLITE_TABLEID).unwrap(); + let sqlite_schema_cursor_id = + program.alloc_cursor_id(CursorType::BTreeTable(table.clone())); + program.emit_insn(Insn::OpenWrite { + cursor_id: sqlite_schema_cursor_id, + root_page: 1usize.into(), + db: 0, + }); + + let resolver = Resolver::new(schema, syms); + // Add the table entry to sqlite_schema + emit_schema_entry( + &mut program, + &resolver, + sqlite_schema_cursor_id, + None, + SchemaEntryType::Table, + "sqlite_stat1", + "sqlite_stat1", + table_root_reg, + Some(sql.to_string()), + )?; + //FIXME: Emit SetCookie? + let parse_schema_where_clause = + "tbl_name = 'sqlite_stat1' AND type != 'trigger'".to_string(); + program.emit_insn(Insn::ParseSchema { + db: sqlite_schema_cursor_id, + where_clause: Some(parse_schema_where_clause), + }); }; if target_schema.columns().iter().any(|c| c.primary_key) { @@ -100,13 +170,11 @@ pub fn translate_analyze( } // Count the number of rows in the target table, and insert it into sqlite_stat1. - let sqlite_stat1 = schema - .get_btree_table("sqlite_stat1") - .expect("sqlite_stat1 either pre-existed or was just created"); + let sqlite_stat1 = sqlite_stat1_btreetable; let stat_cursor = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_stat1.clone())); program.emit_insn(Insn::OpenWrite { cursor_id: stat_cursor, - root_page: Literal(sqlite_stat1.root_page), + root_page: sqlite_stat1_source, db: 0, }); let target_cursor = program.alloc_cursor_id(CursorType::BTreeTable(target_btree.clone())); diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 86cd60f52..7d29a8173 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -148,7 +148,7 @@ pub fn translate_inner( ast::Stmt::AlterTable(alter) => { translate_alter_table(alter, syms, schema, program, connection, input)? 
} - ast::Stmt::Analyze { name } => translate_analyze(name, schema, program)?, + ast::Stmt::Analyze { name } => translate_analyze(name, schema, syms, program)?, ast::Stmt::Attach { expr, db_name, key } => { attach::translate_attach(&expr, &db_name, &key, schema, syms, program)? } diff --git a/testing/analyze.test b/testing/analyze.test index a1761bf45..7d7f95066 100755 --- a/testing/analyze.test +++ b/testing/analyze.test @@ -5,28 +5,26 @@ source $testdir/tester.tcl # Things that do work: do_execsql_test_on_specific_db {:memory:} empty-table { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer); ANALYZE temp; SELECT * FROM sqlite_stat1; } {} do_execsql_test_on_specific_db {:memory:} one-row-table { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer); INSERT INTO temp VALUES (1); ANALYZE temp; SELECT * FROM sqlite_stat1; } {temp||1} -do_execsql_test_on_specific_db {:memory:} analyze-deletes { - CREATE TABLE sqlite_stat1(tbl,idx,stat); - INSERT INTO sqlite_stat1 VALUES ('temp', NULL, 10); +do_execsql_test_on_specific_db {:memory:} analyze-overwrites { CREATE TABLE temp (a integer); INSERT INTO temp VALUES (1); ANALYZE temp; + INSERT INTO temp VALUES (2); + ANALYZE temp; SELECT * FROM sqlite_stat1; -} {temp||1} +} {temp||2} # Things that don't work: @@ -38,25 +36,17 @@ do_execsql_test_in_memory_error analyze-one-database-fails { ANALYZE main; } {.*ANALYZE.*not supported.*} -do_execsql_test_in_memory_error analyze-without-stat-table-fails { - CREATE TABLE temp (a integer); - ANALYZE temp; -} {.*ANALYZE.*not supported.*} - do_execsql_test_in_memory_error analyze-table-with-pk-fails { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer primary key); ANALYZE temp; } {.*ANALYZE.*not supported.*} do_execsql_test_in_memory_error analyze-table-without-rowid-fails { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer primary key) WITHOUT ROWID; ANALYZE temp; } {.*ANALYZE.*not supported.*} 
do_execsql_test_in_memory_error analyze-index-fails { - CREATE TABLE sqlite_stat1(tbl,idx,stat); CREATE TABLE temp (a integer, b integer); CREATE INDEX temp_b ON temp (b); ANALYZE temp_b; From 328c5edf4d82f3cffd91dd8ffd530b8644f80a7c Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Mon, 25 Aug 2025 02:17:53 +0530 Subject: [PATCH 18/73] Add `PRAGMA cipher` to allow setting cipher algo --- core/lib.rs | 12 ++++++++++++ core/pragma.rs | 4 ++++ core/storage/encryption.rs | 24 ++++++++++++++++++++++++ core/translate/pragma.rs | 17 ++++++++++++++++- parser/src/ast.rs | 4 ++++ 5 files changed, 60 insertions(+), 1 deletion(-) diff --git a/core/lib.rs b/core/lib.rs index e8e629887..3a8c7ead1 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -41,6 +41,7 @@ pub mod numeric; mod numeric; use crate::incremental::view::ViewTransactionState; +use crate::storage::encryption::CipherMode; use crate::translate::optimizer::optimize_plan; use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME; #[cfg(all(feature = "fs", feature = "conn_raw_api"))] @@ -455,6 +456,7 @@ impl Database { metrics: RefCell::new(ConnectionMetrics::new()), is_nested_stmt: Cell::new(false), encryption_key: RefCell::new(None), + encryption_cipher: RefCell::new(None), }); let builtin_syms = self.builtin_syms.borrow(); // add built-in extensions symbols to the connection to prevent having to load each time @@ -886,6 +888,7 @@ pub struct Connection { /// Generally this is only true for ParseSchema. 
is_nested_stmt: Cell, encryption_key: RefCell>, + encryption_cipher: RefCell>, } impl Connection { @@ -1961,6 +1964,15 @@ impl Connection { let pager = self.pager.borrow(); pager.set_encryption_context(&key); } + + pub fn set_encryption_cipher(&self, cipher: CipherMode) { + tracing::trace!("setting encryption cipher for connection"); + self.encryption_cipher.replace(Some(cipher)); + } + + pub fn get_encryption_cipher_mode(&self) -> Option { + self.encryption_cipher.borrow().clone() + } } pub struct Statement { diff --git a/core/pragma.rs b/core/pragma.rs index 0ae48b97a..e006963c0 100644 --- a/core/pragma.rs +++ b/core/pragma.rs @@ -111,6 +111,10 @@ pub fn pragma_for(pragma: &PragmaName) -> Pragma { PragmaFlags::Result0 | PragmaFlags::SchemaReq | PragmaFlags::NoColumns1, &["hexkey"], ), + EncryptionCipher => Pragma::new( + PragmaFlags::Result0 | PragmaFlags::SchemaReq | PragmaFlags::NoColumns1, + &["cipher"], + ), } } diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index e8b7f1181..17343f1c0 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -134,6 +134,30 @@ pub enum CipherMode { Aegis256, } +impl TryFrom<&str> for CipherMode { + type Error = LimboError; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "aes256gcm" | "aes-256-gcm" | "aes_256_gcm" => Ok(CipherMode::Aes256Gcm), + "aegis256" | "aegis-256" | "aegis_256" => Ok(CipherMode::Aegis256), + _ => Err(LimboError::InvalidArgument(format!( + "Unknown cipher name: {}", + s + ))), + } + } +} + +impl std::fmt::Display for CipherMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CipherMode::Aes256Gcm => write!(f, "aes256gcm"), + CipherMode::Aegis256 => write!(f, "aegis256"), + } + } +} + impl CipherMode { /// Every cipher requires a specific key size. For 256-bit algorithms, this is 32 bytes. /// For 128-bit algorithms, it would be 16 bytes, etc. 
diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index b53c7af5b..43abf02e7 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -10,7 +10,7 @@ use turso_parser::ast::{PragmaName, QualifiedName}; use super::integrity_check::translate_integrity_check; use crate::pragma::pragma_for; use crate::schema::Schema; -use crate::storage::encryption::EncryptionKey; +use crate::storage::encryption::{CipherMode, EncryptionKey}; use crate::storage::pager::AutoVacuumMode; use crate::storage::pager::Pager; use crate::storage::sqlite3_ondisk::CacheSize; @@ -318,6 +318,12 @@ fn update_pragma( connection.set_encryption_key(key); Ok((program, TransactionMode::None)) } + PragmaName::EncryptionCipher => { + let value = parse_string(&value)?; + let cipher = CipherMode::try_from(value.as_str())?; + connection.set_encryption_cipher(cipher); + Ok((program, TransactionMode::None)) + } } } @@ -589,6 +595,15 @@ fn query_pragma( program.add_pragma_result_column(pragma.to_string()); Ok((program, TransactionMode::None)) } + PragmaName::EncryptionCipher => { + if let Some(cipher) = connection.get_encryption_cipher_mode() { + let register = program.alloc_register(); + program.emit_string8(cipher.to_string(), register); + program.emit_result_row(register, 1); + program.add_pragma_result_column(pragma.to_string()); + } + Ok((program, TransactionMode::None)) + } } } diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 0c9de857c..a73e4ed36 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -1211,6 +1211,10 @@ pub enum PragmaName { AutoVacuum, /// `cache_size` pragma CacheSize, + /// encryption cipher algorithm name for encrypted databases + #[strum(serialize = "cipher")] + #[cfg_attr(feature = "serde", serde(rename = "cipher"))] + EncryptionCipher, /// List databases DatabaseList, /// Encoding - only support utf8 From 48ce2a4a3ef4105ba24b8e01797412a36c371ed1 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Mon, 25 Aug 2025 02:28:57 +0530 Subject: 
[PATCH 19/73] Set encryption ctx when cipher and key are set --- core/lib.rs | 18 ++++++++++++++++-- core/storage/pager.rs | 4 ++-- .../integration/query_processing/encryption.rs | 2 ++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index 3a8c7ead1..46a479a54 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1961,18 +1961,32 @@ impl Connection { pub fn set_encryption_key(&self, key: EncryptionKey) { tracing::trace!("setting encryption key for connection"); *self.encryption_key.borrow_mut() = Some(key.clone()); - let pager = self.pager.borrow(); - pager.set_encryption_context(&key); + self.set_encryption_context(); } pub fn set_encryption_cipher(&self, cipher: CipherMode) { tracing::trace!("setting encryption cipher for connection"); self.encryption_cipher.replace(Some(cipher)); + self.set_encryption_context(); } pub fn get_encryption_cipher_mode(&self) -> Option { self.encryption_cipher.borrow().clone() } + + // if both key and cipher are set, set encryption context on pager + fn set_encryption_context(&self) { + let key_ref = self.encryption_key.borrow(); + let Some(key) = key_ref.as_ref() else { + return; + }; + let Some(cipher) = self.encryption_cipher.borrow().clone() else { + return; + }; + tracing::trace!("setting encryption ctx for connection"); + let pager = self.pager.borrow(); + pager.set_encryption_context(cipher, key); + } } pub struct Statement { diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 85f33775e..c22a09c4d 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -2109,8 +2109,8 @@ impl Pager { Ok(IOResult::Done(f(header))) } - pub fn set_encryption_context(&self, key: &EncryptionKey) { - let encryption_ctx = EncryptionContext::new(CipherMode::Aegis256, key).unwrap(); + pub fn set_encryption_context(&self, cipher_mode: CipherMode, key: &EncryptionKey) { + let encryption_ctx = EncryptionContext::new(cipher_mode, key).unwrap(); 
self.encryption_ctx.replace(Some(encryption_ctx.clone())); let Some(wal) = self.wal.as_ref() else { return }; wal.borrow_mut().set_encryption_context(encryption_ctx) diff --git a/tests/integration/query_processing/encryption.rs b/tests/integration/query_processing/encryption.rs index 82bef1ef1..3cca08587 100644 --- a/tests/integration/query_processing/encryption.rs +++ b/tests/integration/query_processing/encryption.rs @@ -18,6 +18,7 @@ fn test_per_page_encryption() -> anyhow::Result<()> { &conn, "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';", )?; + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';")?; run_query( &tmp_db, &conn, @@ -55,6 +56,7 @@ fn test_per_page_encryption() -> anyhow::Result<()> { // let's test the existing db with the key let existing_db = TempDatabase::new_with_existent(&db_path, false); let conn = existing_db.connect_limbo(); + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';")?; run_query( &existing_db, &conn, From 84d20ba60fdbf624f0edd8756464f96fcb3eedcd Mon Sep 17 00:00:00 2001 From: rajajisai Date: Sun, 24 Aug 2025 18:45:46 -0400 Subject: [PATCH 20/73] Use F_FULLSYNC in darwin based operating systems --- core/io/unix.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/core/io/unix.rs b/core/io/unix.rs index e10f7f3ec..5877fa194 100644 --- a/core/io/unix.rs +++ b/core/io/unix.rs @@ -259,12 +259,24 @@ impl File for UnixFile { #[instrument(err, skip_all, level = Level::TRACE)] fn sync(&self, c: Completion) -> Result { let file = self.file.lock(); - let result = unsafe { libc::fsync(file.as_raw_fd()) }; + + #[cfg(not(any(target_os = "macos", target_os = "ios")))] + let result = libc::fsync(file.as_raw_fd()); + + #[cfg(any(target_os = "macos", target_os = "ios"))] + let result = unsafe { libc::fcntl(file.as_raw_fd(), libc::F_FULLFSYNC) }; + if result == -1 { let e = std::io::Error::last_os_error(); Err(e.into()) } else { + + #[cfg(not(any(target_os = "macos", 
target_os = "ios")))] trace!("fsync"); + + #[cfg(any(target_os = "macos", target_os = "ios"))] + trace!("fcntl(F_FULLFSYNC)"); + c.complete(0); Ok(c) } From 9068a2938057433f39f93f9ac6f6ac4a7d75dbac Mon Sep 17 00:00:00 2001 From: rajajisai Date: Sun, 24 Aug 2025 18:56:05 -0400 Subject: [PATCH 21/73] Use unsafe block --- core/io/unix.rs | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/core/io/unix.rs b/core/io/unix.rs index 5877fa194..b7567b683 100644 --- a/core/io/unix.rs +++ b/core/io/unix.rs @@ -260,11 +260,19 @@ impl File for UnixFile { fn sync(&self, c: Completion) -> Result { let file = self.file.lock(); - #[cfg(not(any(target_os = "macos", target_os = "ios")))] - let result = libc::fsync(file.as_raw_fd()); - - #[cfg(any(target_os = "macos", target_os = "ios"))] - let result = unsafe { libc::fcntl(file.as_raw_fd(), libc::F_FULLFSYNC) }; + let result = unsafe { + + #[cfg(not(any(target_os = "macos", target_os = "ios")))] + { + libc::fsync(file.as_raw_fd()) + } + + #[cfg(any(target_os = "macos", target_os = "ios"))] + { + libc::fcntl(file.as_raw_fd(), libc::F_FULLFSYNC) + } + + }; if result == -1 { let e = std::io::Error::last_os_error(); From 54ff656c9d9fb612a06944b11cbd55df331e2f62 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Mon, 25 Aug 2025 08:49:22 +0300 Subject: [PATCH 22/73] Do not clear txn state inside nested statement If a connection does e.g. CREATE TABLE, it will start a "child statement" to reparse the schema. That statement does not start its own transaction, and so should not try to end the existing one either. We had a logic bug where these steps would happen: - `CREATE TABLE` executed successfully - pread fault happens inside `ParseSchema` child stmt - `handle_program_error()` is called - `pager.end_tx()` returns immediately because `is_nested_stmt` is true and we correctly no-op it. 
- however, crucially: `handle_program_error()` then sets tx state to None - parent statement now catches error from nested stmt and calls `handle_program_error()`, which calls `pager.end_tx()` again, and since txn state is None, when it calls `rollback()` we panic on the assertion `"dirty pages should be empty for read txn"` Solution: Do not do _any_ error processing in `handle_program_error()` inside a nested stmt. This means that the parent write txn is still active when it processes the error from the child and we avoid this panic. --- core/vdbe/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 5a9cbe646..a12e5772e 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -857,6 +857,10 @@ pub fn handle_program_error( connection: &Connection, err: &LimboError, ) -> Result<()> { + if connection.is_nested_stmt.get() { + // Errors from nested statements are handled by the parent statement. + return Ok(()); + } match err { // Transaction errors, e.g. trying to start a nested transaction, do not cause a rollback. LimboError::TxError(_) => {} From 4ea8cd0007a8456fd3d2fddbc762822e4a7f992b Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Fri, 8 Aug 2025 14:03:29 +0300 Subject: [PATCH 23/73] refactor/btree: rewrite the free_cell_range() function i had a rough time reading this function earlier and trying to understand it, so rewrote it in a way that, to me, is much more readable. --- core/storage/btree.rs | 239 ++++++++++++++++++++------------- core/storage/sqlite3_ondisk.rs | 10 ++ 2 files changed, 157 insertions(+), 92 deletions(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index ca5624220..15e43d25c 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -6384,121 +6384,176 @@ fn page_insert_array( /// This function also updates the freeblock list in the page. /// Freeblocks are used to keep track of free space in the page, /// and are organized as a linked list. 
+/// +/// This function may merge the freed cell range into either the next freeblock, +/// previous freeblock, or both. fn free_cell_range( page: &mut PageContent, mut offset: usize, len: usize, usable_space: usize, ) -> Result<()> { - if len < 4 { - return_corrupt!("Minimum cell size is 4"); + const CELL_SIZE_MIN: usize = 4; + if len < CELL_SIZE_MIN { + return_corrupt!("free_cell_range: minimum cell size is {CELL_SIZE_MIN}"); } - - if offset > usable_space.saturating_sub(4) { - return_corrupt!("Start offset beyond usable space"); + if offset > usable_space.saturating_sub(CELL_SIZE_MIN) { + return_corrupt!("free_cell_range: start offset beyond usable space: offset={offset} usable_space={usable_space}"); } let mut size = len; let mut end = offset + len; - let mut pointer_to_pc = page.offset + 1; - // if the freeblock list is empty, we set this block as the first freeblock in the page header. - let pc = if page.first_freeblock() == 0 { - 0 - } else { - // if the freeblock list is not empty, and the offset is greater than the first freeblock, - // then we need to do some more calculation to figure out where to insert the freeblock - // in the freeblock linked list. 
- let first_block = page.first_freeblock() as usize; - - let mut pc = first_block; - - while pc < offset { - if pc <= pointer_to_pc { - if pc == 0 { - break; - } - return_corrupt!("free cell range free block not in ascending order"); - } - - let next = page.read_u16_no_offset(pc) as usize; - pointer_to_pc = pc; - pc = next; + let cur_content_area = page.cell_content_area() as usize; + let first_block = page.first_freeblock() as usize; + if first_block == 0 { + if offset < cur_content_area { + return_corrupt!("free_cell_range: free block before content area: offset={offset} cell_content_area={cur_content_area}"); } - - if pc > usable_space - 4 { - return_corrupt!("Free block beyond usable space"); + if offset == cur_content_area { + // if the freeblock list is empty and the freed range is exactly at the beginning of the content area, + // we are not creating a freeblock; instead we are just extending the unallocated region. + page.write_cell_content_area(end); + } else { + // otherwise we set it as the first freeblock in the page header. + let offset_u16: u16 = offset + .try_into() + .unwrap_or_else(|_| panic!("offset={offset} is too large to fit in a u16")); + page.write_first_freeblock(offset_u16); + let size_u16: u16 = size + .try_into() + .unwrap_or_else(|_| panic!("size={size} is too large to fit in a u16")); + page.write_freeblock(offset_u16, size_u16, None); } - let mut removed_fragmentation = 0; - if pc > 0 && offset + len + 3 >= pc { - removed_fragmentation = (pc - end) as u8; + return Ok(()); + } - if end > pc { - return_corrupt!("Invalid block overlap"); - } - end = pc + page.read_u16_no_offset(pc + 2) as usize; + // if the freeblock list is not empty, we need to find the correct position to insert the new freeblock + // resulting from the freeing of this cell range; we may also be able to merge the freed range into existing freeblocks. 
+ let mut prev_block = None; + let mut next_block = Some(first_block); + + while let Some(next) = next_block { + if prev_block.is_some_and(|prev| next <= prev) { + return_corrupt!("free_cell_range: freeblocks not in ascending order: next_block={next} prev_block={prev_block:?}"); + } + if next >= offset { + break; + } + prev_block = Some(next); + next_block = match page.read_u16_no_offset(next) { + // Freed range extends beyond the last freeblock, so we are creating a new freeblock. + 0 => None, + next => Some(next as usize), + }; + } + + if let Some(next) = next_block { + if next + CELL_SIZE_MIN > usable_space { + return_corrupt!("free_cell_range: free block beyond usable space: next_block={next} usable_space={usable_space}"); + } + } + let mut removed_fragmentation = 0; + const SINGLE_FRAGMENT_SIZE_MAX: usize = CELL_SIZE_MIN - 1; + + if end > usable_space { + return_corrupt!("free_cell_range: freed range extends beyond usable space: offset={offset} len={len} end={end} usable_space={usable_space}"); + } + + // If the freed range extends into the next freeblock, we will merge the freed range into it. + // If there is a 1-3 byte gap between the freed range and the next freeblock, we are effectively + // clearing that amount of fragmented bytes, since a 1-3 byte range cannot be a valid cell. + if let Some(next) = next_block { + if end + SINGLE_FRAGMENT_SIZE_MAX >= next { + removed_fragmentation = (next - end) as u8; + let next_size = page.read_u16_no_offset(next + 2) as usize; + end = next + next_size; if end > usable_space { - return_corrupt!("Coalesced block extends beyond page"); + return_corrupt!("free_cell_range: coalesced block extends beyond page: offset={offset} len={len} end={end} usable_space={usable_space}"); } size = end - offset; - pc = page.read_u16_no_offset(pc) as usize; + // Since we merged the two freeblocks, we need to update the next_block to the next freeblock in the list. 
+ next_block = match page.read_u16_no_offset(next) { + 0 => None, + next => Some(next as usize), + }; } + } - if pointer_to_pc > page.offset + 1 { - let prev_end = pointer_to_pc + page.read_u16_no_offset(pointer_to_pc + 2) as usize; - if prev_end + 3 >= offset { - if prev_end > offset { - return_corrupt!("Invalid previous block overlap"); + // If the freed range extends into the previous freeblock, we will merge them similarly as above. + if let Some(prev) = prev_block { + let prev_size = page.read_u16_no_offset(prev + 2) as usize; + let prev_end = prev + prev_size; + if prev_end > offset { + return_corrupt!( + "free_cell_range: previous block overlap: prev_end={prev_end} offset={offset}" + ); + } + // If the previous freeblock extends into the freed range, we will merge the freed range into the + // previous freeblock and clear any 1-3 byte fragmentation in between, similarly as above + if prev_end + SINGLE_FRAGMENT_SIZE_MAX >= offset { + removed_fragmentation += (offset - prev_end) as u8; + size = end - prev; + offset = prev; + } + } + + let cur_frag_free_bytes = page.num_frag_free_bytes(); + if removed_fragmentation > cur_frag_free_bytes { + return_corrupt!("free_cell_range: invalid fragmentation count: removed_fragmentation={removed_fragmentation} num_frag_free_bytes={cur_frag_free_bytes}"); + } + let frag = cur_frag_free_bytes - removed_fragmentation; + page.write_fragmented_bytes_count(frag); + + if offset < cur_content_area { + return_corrupt!("free_cell_range: free block before content area: offset={offset} cell_content_area={cur_content_area}"); + } + + // As above, if the freed range is exactly at the beginning of the content area, we are not creating a freeblock; + // instead we are just extending the unallocated region. 
+ if offset == cur_content_area { + if prev_block.is_some_and(|prev| prev != first_block) { + return_corrupt!("free_cell_range: invalid content area merge - freed range should have been merged with previous freeblock: prev_block={prev_block:?} first_block={first_block}"); + } + // If we get here, we are freeing data from the left end of the content area, + // so we are extending the unallocated region instead of creating a freeblock. + // We update the first freeblock to be the next one, and shrink the content area to start from the end + // of the freed range. + match next_block { + Some(next) => { + if next <= end { + return_corrupt!("free_cell_range: invalid content area merge - first freeblock should either be 0 or greater than the content area start: next_block={next} end={end}"); } - removed_fragmentation += (offset - prev_end) as u8; - size = end - pointer_to_pc; - offset = pointer_to_pc; + let next_u16: u16 = next + .try_into() + .unwrap_or_else(|_| panic!("next={next} is too large to fit in a u16")); + page.write_first_freeblock(next_u16); + } + None => { + page.write_first_freeblock(0); + } + } - if removed_fragmentation > page.num_frag_free_bytes() { - return_corrupt!(format!( - "Invalid fragmentation count. 
Had {} and removed {}", - page.num_frag_free_bytes(), - removed_fragmentation - )); - } - let frag = page.num_frag_free_bytes() - removed_fragmentation; - page.write_fragmented_bytes_count(frag); - pc - }; - - if (offset as u32) <= page.cell_content_area() { - if (offset as u32) < page.cell_content_area() { - return_corrupt!("Free block before content area"); - } - if pointer_to_pc != page.offset + offset::BTREE_FIRST_FREEBLOCK { - return_corrupt!("Invalid content area merge"); - } - turso_assert!( - pc < PageSize::MAX as usize, - "pc={pc} PageSize::MAX={}", - PageSize::MAX - ); - page.write_first_freeblock(pc as u16); page.write_cell_content_area(end); } else { - turso_assert!( - pointer_to_pc < PageSize::MAX as usize, - "pointer_to_pc={pointer_to_pc} PageSize::MAX={}", - PageSize::MAX - ); - turso_assert!( - offset < PageSize::MAX as usize, - "offset={offset} PageSize::MAX={}", - PageSize::MAX - ); - turso_assert!( - size < PageSize::MAX as usize, - "size={size} PageSize::MAX={}", - PageSize::MAX - ); - page.write_u16_no_offset(pointer_to_pc, offset as u16); - page.write_u16_no_offset(offset, pc as u16); - page.write_u16_no_offset(offset + 2, size as u16); + // If we are creating a new freeblock: + // a) if it's the first one, we update the header to indicate so, + // b) if it's not the first one, we update the previous freeblock to point to the new one, + // and the new one to point to the next one. 
+ let offset_u16: u16 = offset + .try_into() + .unwrap_or_else(|_| panic!("offset={offset} is too large to fit in a u16")); + if let Some(prev) = prev_block { + page.write_u16_no_offset(prev, offset_u16); + } else { + page.write_first_freeblock(offset_u16); + } + let size_u16: u16 = size + .try_into() + .unwrap_or_else(|_| panic!("size={size} is too large to fit in a u16")); + let next_block_u16 = next_block.map(|b| { + b.try_into() + .unwrap_or_else(|_| panic!("next_block={b} is too large to fit in a u16")) + }); + page.write_freeblock(offset_u16, size_u16, next_block_u16); } Ok(()) @@ -6893,7 +6948,7 @@ fn compute_free_space(page: &PageContent, usable_space: usize) -> usize { // Next should always be 0 (NULL) at this point since we have reached the end of the freeblocks linked list assert_eq!( next, 0, - "corrupted page: freeblocks list not in ascending order" + "corrupted page: freeblocks list not in ascending order: cur_freeblock_ptr={cur_freeblock_ptr} size={size} next={next}" ); assert!( diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 8ec4c861f..da99edc59 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -568,6 +568,16 @@ impl PageContent { self.write_u16(BTREE_FIRST_FREEBLOCK, value); } + /// Write a freeblock to the page content at the given absolute offset. + /// Parameters: + /// - offset: the absolute offset of the freeblock + /// - size: the size of the freeblock + /// - next_block: the absolute offset of the next freeblock, or None if this is the last freeblock + pub fn write_freeblock(&self, offset: u16, size: u16, next_block: Option) { + self.write_u16_no_offset(offset as usize, next_block.unwrap_or(0)); + self.write_u16_no_offset(offset as usize + 2, size); + } + /// Write the number of cells on this page. 
pub fn write_cell_count(&self, value: u16) { self.write_u16(BTREE_CELL_COUNT, value); From dc6bcd4d41c01e0df33d059e8853215fd79440ad Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Fri, 8 Aug 2025 15:08:39 +0300 Subject: [PATCH 24/73] refactor/btree: rewrite find_free_cell() --- core/storage/btree.rs | 128 ++++++++++++++++++++++----------- core/storage/sqlite3_ondisk.rs | 27 ++++++- 2 files changed, 112 insertions(+), 43 deletions(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 15e43d25c..855641b08 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -6081,59 +6081,104 @@ impl BTreePageInner { } } -/// Try to find a free block available and allocate it if found -fn find_free_cell(page_ref: &PageContent, usable_space: usize, amount: usize) -> Result { +/// Try to find a freeblock inside the cell content area that is large enough to fit the given amount of bytes. +/// Used to check if a cell can be inserted into a freeblock to reduce fragmentation. +/// Returns the absolute byte offset of the freeblock if found. 
+fn find_free_slot( + page_ref: &PageContent, + usable_space: usize, + amount: usize, +) -> Result> { + const CELL_SIZE_MIN: usize = 4; // NOTE: freelist is in ascending order of keys and pc // unuse_space is reserved bytes at the end of page, therefore we must substract from maxpc - let mut prev_pc = page_ref.offset + offset::BTREE_FIRST_FREEBLOCK; - let mut pc = page_ref.first_freeblock() as usize; - let maxpc = usable_space - amount; + let mut prev_block = None; + let mut cur_block = match page_ref.first_freeblock() { + 0 => None, + first_block => Some(first_block as usize), + }; - while pc <= maxpc { - if pc + 4 > usable_space { + let max_start_offset = usable_space - amount; + + while let Some(cur) = cur_block { + if cur + CELL_SIZE_MIN > usable_space { return_corrupt!("Free block header extends beyond page"); } - let next = page_ref.read_u16_no_offset(pc); - let size = page_ref.read_u16_no_offset(pc + 2); + let (next, size) = { + let cur_u16: u16 = cur + .try_into() + .unwrap_or_else(|_| panic!("cur={cur} is too large to fit in a u16")); + let (next, size) = page_ref.read_freeblock(cur_u16); + (next as usize, size as usize) + }; - if amount <= size as usize { - let new_size = size as usize - amount; - if new_size < 4 { - // The code is checking if using a free slot that would leave behind a very small fragment (x < 4 bytes) - // would cause the total fragmentation to exceed the limit of 60 bytes - // check sqlite docs https://www.sqlite.org/fileformat.html#:~:text=A%20freeblock%20requires,not%20exceed%2060 - if page_ref.num_frag_free_bytes() > 57 { - return Ok(0); - } - // Delete the slot from freelist and update the page's fragment count. 
- page_ref.write_u16_no_offset(prev_pc, next); - let frag = page_ref.num_frag_free_bytes() + new_size as u8; - page_ref.write_fragmented_bytes_count(frag); - return Ok(pc); - } else if new_size + pc > maxpc { - return_corrupt!("Free block extends beyond page end"); - } else { - // Requested amount fits inside the current free slot so we reduce its size - // to account for newly allocated space. - page_ref.write_u16_no_offset(pc + 2, new_size as u16); - return Ok(pc + new_size); + // Doesn't fit in this freeblock, try the next one. + if amount > size { + if next == 0 { + // No next -> can't fit. + return Ok(None); } - } - prev_pc = pc; - pc = next as usize; - if pc <= prev_pc { - if pc != 0 { + prev_block = cur_block; + if next <= cur { return_corrupt!("Free list not in ascending order"); } - return Ok(0); + cur_block = Some(next); + continue; + } + + let new_size = size - amount; + // If the freeblock's new size is < CELL_SIZE_MIN, the freeblock is deleted and the remaining bytes + // become fragmented free bytes. + if new_size < CELL_SIZE_MIN { + if page_ref.num_frag_free_bytes() > 57 { + // SQLite has a fragmentation limit of 60 bytes. + // check sqlite docs https://www.sqlite.org/fileformat.html#:~:text=A%20freeblock%20requires,not%20exceed%2060 + return Ok(None); + } + // Delete the slot from freelist and update the page's fragment count. 
+ match prev_block { + Some(prev) => { + let prev_u16: u16 = prev + .try_into() + .unwrap_or_else(|_| panic!("prev={prev} is too large to fit in a u16")); + let next_u16: u16 = next + .try_into() + .unwrap_or_else(|_| panic!("next={next} is too large to fit in a u16")); + page_ref.write_freeblock_next_ptr(prev_u16, next_u16); + } + None => { + let next_u16: u16 = next + .try_into() + .unwrap_or_else(|_| panic!("next={next} is too large to fit in a u16")); + page_ref.write_first_freeblock(next_u16); + } + } + let new_size_u8: u8 = new_size + .try_into() + .unwrap_or_else(|_| panic!("new_size={new_size} is too large to fit in a u8")); + let frag = page_ref.num_frag_free_bytes() + new_size_u8; + page_ref.write_fragmented_bytes_count(frag); + return Ok(cur_block); + } else if new_size + cur > max_start_offset { + return_corrupt!("Free block extends beyond page end"); + } else { + // Requested amount fits inside the current free slot so we reduce its size + // to account for newly allocated space. + let cur_u16: u16 = cur + .try_into() + .unwrap_or_else(|_| panic!("cur={cur} is too large to fit in a u16")); + let new_size_u16: u16 = new_size + .try_into() + .unwrap_or_else(|_| panic!("new_size={new_size} is too large to fit in a u16")); + page_ref.write_freeblock_size(cur_u16, new_size_u16); + // Return the offset immediately after the shrunk freeblock. + return Ok(Some(cur + new_size)); } } - if pc > maxpc + amount - 4 { - return_corrupt!("Free block chain extends beyond page end"); - } - Ok(0) + + Ok(None) } pub fn btree_init_page(page: &BTreePage, page_type: PageType, offset: usize, usable_space: usize) { @@ -6984,8 +7029,7 @@ fn allocate_cell_space( && unallocated_region_start + CELL_PTR_SIZE_BYTES <= cell_content_area_start { // find slot - let pc = find_free_cell(page_ref, usable_space, amount)?; - if pc != 0 { + if let Some(pc) = find_free_slot(page_ref, usable_space, amount)? { // we can fit the cell in a freeblock. 
return Ok(pc as u16); } diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index da99edc59..53ebf04a4 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -574,10 +574,35 @@ impl PageContent { /// - size: the size of the freeblock /// - next_block: the absolute offset of the next freeblock, or None if this is the last freeblock pub fn write_freeblock(&self, offset: u16, size: u16, next_block: Option) { - self.write_u16_no_offset(offset as usize, next_block.unwrap_or(0)); + self.write_freeblock_next_ptr(offset, next_block.unwrap_or(0)); + self.write_freeblock_size(offset, size); + } + + /// Write the new size of a freeblock. + /// Parameters: + /// - offset: the absolute offset of the freeblock + /// - size: the new size of the freeblock + pub fn write_freeblock_size(&self, offset: u16, size: u16) { self.write_u16_no_offset(offset as usize + 2, size); } + /// Write the absolute offset of the next freeblock. + /// Parameters: + /// - offset: the absolute offset of the current freeblock + /// - next_block: the absolute offset of the next freeblock + pub fn write_freeblock_next_ptr(&self, offset: u16, next_block: u16) { + self.write_u16_no_offset(offset as usize, next_block); + } + + /// Read a freeblock from the page content at the given absolute offset. + /// Returns (absolute offset of next freeblock, size of the current freeblock) + pub fn read_freeblock(&self, offset: u16) -> (u16, u16) { + ( + self.read_u16_no_offset(offset as usize), + self.read_u16_no_offset(offset as usize + 2), + ) + } + /// Write the number of cells on this page. 
pub fn write_cell_count(&self, value: u16) { self.write_u16(BTREE_CELL_COUNT, value); From b162f89b73c781d321712b03e0595caea1ff98a7 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Mon, 25 Aug 2025 07:36:12 +0300 Subject: [PATCH 25/73] sqlite3: Implement sqlite3_db_filename() --- sqlite3/include/sqlite3.h | 2 ++ sqlite3/src/lib.rs | 38 +++++++++++++++++++++++++++++------ sqlite3/tests/compat/mod.rs | 40 +++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 6 deletions(-) diff --git a/sqlite3/include/sqlite3.h b/sqlite3/include/sqlite3.h index 68a8944e5..f56f38b8c 100644 --- a/sqlite3/include/sqlite3.h +++ b/sqlite3/include/sqlite3.h @@ -76,6 +76,8 @@ int sqlite3_close(sqlite3 *db); int sqlite3_close_v2(sqlite3 *db); +const char *sqlite3_db_filename(sqlite3 *db, const char *db_name); + int sqlite3_trace_v2(sqlite3 *_db, unsigned int _mask, void (*_callback)(unsigned int, void*, void*, void*), diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index b7bd6529a..46d6d64b3 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -56,6 +56,7 @@ struct sqlite3Inner { pub(crate) malloc_failed: bool, pub(crate) e_open_state: u8, pub(crate) p_err: *mut ffi::c_void, + pub(crate) filename: CString, } impl sqlite3 { @@ -63,6 +64,7 @@ impl sqlite3 { io: Arc, db: Arc, conn: Arc, + filename: CString, ) -> Self { let inner = sqlite3Inner { _io: io, @@ -73,6 +75,7 @@ impl sqlite3 { malloc_failed: false, e_open_state: SQLITE_STATE_OPEN, p_err: std::ptr::null_mut(), + filename, }; #[allow(clippy::arc_with_non_send_sync)] let inner = Arc::new(Mutex::new(inner)); @@ -132,26 +135,30 @@ pub unsafe extern "C" fn sqlite3_open( if db_out.is_null() { return SQLITE_MISUSE; } - let filename = CStr::from_ptr(filename); - let filename = match filename.to_str() { + let filename_cstr = CStr::from_ptr(filename); + let filename_str = match filename_cstr.to_str() { Ok(s) => s, Err(_) => return SQLITE_MISUSE, }; - let io: Arc = match filename { + let io: Arc = match 
filename_str { ":memory:" => Arc::new(turso_core::MemoryIO::new()), _ => match turso_core::PlatformIO::new() { Ok(io) => Arc::new(io), Err(_) => return SQLITE_CANTOPEN, }, }; - match turso_core::Database::open_file(io.clone(), filename, false, false) { + match turso_core::Database::open_file(io.clone(), filename_str, false, false) { Ok(db) => { let conn = db.connect().unwrap(); - *db_out = Box::leak(Box::new(sqlite3::new(io, db, conn))); + let filename = match filename_str { + ":memory:" => CString::new("".to_string()).unwrap(), + _ => CString::from(filename_cstr), + }; + *db_out = Box::leak(Box::new(sqlite3::new(io, db, conn, filename))); SQLITE_OK } Err(e) => { - trace!("error opening database {}: {:?}", filename, e); + trace!("error opening database {}: {:?}", filename_str, e); SQLITE_CANTOPEN } } @@ -184,6 +191,25 @@ pub unsafe extern "C" fn sqlite3_close_v2(db: *mut sqlite3) -> ffi::c_int { sqlite3_close(db) } +#[no_mangle] +pub unsafe extern "C" fn sqlite3_db_filename( + db: *mut sqlite3, + db_name: *const ffi::c_char, +) -> *const ffi::c_char { + if db.is_null() { + return std::ptr::null(); + } + if !db_name.is_null() { + let name = CStr::from_ptr(db_name); + if name.to_bytes() != b"main" { + return std::ptr::null(); + } + } + let db = &*db; + let inner = db.inner.lock().unwrap(); + inner.filename.as_ptr() +} + #[no_mangle] pub unsafe extern "C" fn sqlite3_trace_v2( _db: *mut sqlite3, diff --git a/sqlite3/tests/compat/mod.rs b/sqlite3/tests/compat/mod.rs index a40ae7538..94361cf11 100644 --- a/sqlite3/tests/compat/mod.rs +++ b/sqlite3/tests/compat/mod.rs @@ -19,6 +19,7 @@ extern "C" { fn sqlite3_libversion_number() -> i32; fn sqlite3_close(db: *mut sqlite3) -> i32; fn sqlite3_open(filename: *const libc::c_char, db: *mut *mut sqlite3) -> i32; + fn sqlite3_db_filename(db: *mut sqlite3, db_name: *const libc::c_char) -> *const libc::c_char; fn sqlite3_prepare_v2( db: *mut sqlite3, sql: *const libc::c_char, @@ -1279,4 +1280,43 @@ mod tests { 
assert_eq!(sqlite3_finalize(stmt), SQLITE_OK); } } + + #[test] + fn test_sqlite3_db_filename() { + const SQLITE_OK: i32 = 0; + + unsafe { + // Test with in-memory database + let mut db: *mut sqlite3 = ptr::null_mut(); + assert_eq!(sqlite3_open(c":memory:".as_ptr(), &mut db), SQLITE_OK); + let filename = sqlite3_db_filename(db, c"main".as_ptr()); + assert!(!filename.is_null()); + let filename_str = std::ffi::CStr::from_ptr(filename).to_str().unwrap(); + assert_eq!(filename_str, ""); + assert_eq!(sqlite3_close(db), SQLITE_OK); + + // Open a file-backed database + let temp_file = tempfile::NamedTempFile::with_suffix(".db").unwrap(); + let path = std::ffi::CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let mut db = ptr::null_mut(); + assert_eq!(sqlite3_open(path.as_ptr(), &mut db), SQLITE_OK); + + // Test with "main" database name + let filename = sqlite3_db_filename(db, c"main".as_ptr()); + assert!(!filename.is_null()); + let filename_str = std::ffi::CStr::from_ptr(filename).to_str().unwrap(); + assert_eq!(filename_str, temp_file.path().to_str().unwrap()); + + // Test with NULL database name (defaults to main) + let filename_default = sqlite3_db_filename(db, ptr::null()); + assert!(!filename_default.is_null()); + assert_eq!(filename, filename_default); + + // Test with non-existent database name + let filename = sqlite3_db_filename(db, c"temp".as_ptr()); + assert!(filename.is_null()); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } } From c62b87d9b6b0583bfd5435110c0d340cd5219dcb Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 22 Aug 2025 17:16:39 +0400 Subject: [PATCH 26/73] read from database file only if max_frame_read_lock_index is 0 and max_frame > min_frame - transaction which was started with max_frame = 0 and max_frame_read_lock_index = 0 can write to the WAL and in this case it needs to read data back from WAL - without cache spilling its hard to reproduce this issue for the turso-db now, but I stumbled into this issue with 
sync-engine which do weird stuff with the WAL which "simulates" cache spilling behaviour to some extent --- core/storage/pager.rs | 33 ++++++------ core/storage/wal.rs | 35 ++++++++++--- tests/integration/functions/test_wal_api.rs | 50 +++++++++++++++++++ .../query_processing/test_write_path.rs | 3 +- 4 files changed, 99 insertions(+), 22 deletions(-) diff --git a/core/storage/pager.rs b/core/storage/pager.rs index c718485c1..9a98d30d7 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1416,22 +1416,25 @@ impl Pager { "wal_insert_frame() called on database without WAL".to_string(), )); }; - let mut wal = wal.borrow_mut(); let (header, raw_page) = parse_wal_frame_header(frame); - wal.write_frame_raw( - self.buffer_pool.clone(), - frame_no, - header.page_number as u64, - header.db_size as u64, - raw_page, - )?; - if let Some(page) = self.cache_get(header.page_number as usize) { - let content = page.get_contents(); - content.as_ptr().copy_from_slice(raw_page); - turso_assert!( - page.get().id == header.page_number as usize, - "page has unexpected id" - ); + + { + let mut wal = wal.borrow_mut(); + wal.write_frame_raw( + self.buffer_pool.clone(), + frame_no, + header.page_number as u64, + header.db_size as u64, + raw_page, + )?; + if let Some(page) = self.cache_get(header.page_number as usize) { + let content = page.get_contents(); + content.as_ptr().copy_from_slice(raw_page); + turso_assert!( + page.get().id == header.page_number as usize, + "page has unexpected id" + ); + } } if header.page_number == 1 { let db_size = self diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 35882bf7d..a875af365 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -815,14 +815,14 @@ impl Wal for WalFile { // WAL and fetch pages directly from the DB file. We do this // by taking read‑lock 0, and capturing the latest state. 
if shared_max == nbackfills { - let lock_idx = 0; - if !self.get_shared().read_locks[lock_idx].read() { + let lock_0_idx = 0; + if !self.get_shared().read_locks[lock_0_idx].read() { return Ok((LimboResult::Busy, db_changed)); } // we need to keep self.max_frame set to the appropriate // max frame in the wal at the time this transaction starts. self.max_frame = shared_max; - self.max_frame_read_lock_index.set(lock_idx); + self.max_frame_read_lock_index.set(lock_0_idx); self.min_frame = nbackfills + 1; self.last_checksum = last_checksum; return Ok((LimboResult::Ok, db_changed)); @@ -965,7 +965,7 @@ impl Wal for WalFile { } // Snapshot is stale, give up and let caller retry from scratch - tracing::debug!("unable to upgrade transaction from read to write: snapshot is stale, give up and let caller retry from scratch"); + tracing::debug!("unable to upgrade transaction from read to write: snapshot is stale, give up and let caller retry from scratch, self.max_frame={}, shared_max={}", self.max_frame, shared_max); shared.write_lock.unlock(); Ok(LimboResult::Busy) } @@ -1000,8 +1000,18 @@ impl Wal for WalFile { "frame_watermark must be >= than current WAL backfill amount: frame_watermark={:?}, nBackfill={}", frame_watermark, self.get_shared().nbackfills.load(Ordering::Acquire) ); - // if we are holding read_lock 0, skip and read right from db file. - if self.max_frame_read_lock_index.get() == 0 { + // if we are holding read_lock 0 and didn't write anything to the WAL, skip and read right from db file. 
+ // + // note, that max_frame_read_lock_index is set to 0 only when shared_max_frame == nbackfill in which case + // min_frame is set to nbackfill + 1 and max_frame is set to shared_max_frame + // + // by default, SQLite tries to restart log file in this case - but for now let's keep it simple in the turso-db + if self.max_frame_read_lock_index.get() == 0 && self.max_frame < self.min_frame { + tracing::debug!( + "find_frame(page_id={}, frame_watermark={:?}): max_frame is 0 - read from DB file", + page_id, + frame_watermark, + ); return Ok(None); } let shared = self.get_shared(); @@ -1009,8 +1019,21 @@ impl Wal for WalFile { let range = frame_watermark .map(|x| 0..=x) .unwrap_or(self.min_frame..=self.max_frame); + tracing::debug!( + "find_frame(page_id={}, frame_watermark={:?}): min_frame={}, max_frame={}", + page_id, + frame_watermark, + self.min_frame, + self.max_frame + ); if let Some(list) = frames.get(&page_id) { if let Some(f) = list.iter().rfind(|&&f| range.contains(&f)) { + tracing::debug!( + "find_frame(page_id={}, frame_watermark={:?}): found frame={}", + page_id, + frame_watermark, + *f + ); return Ok(Some(*f)); } } diff --git a/tests/integration/functions/test_wal_api.rs b/tests/integration/functions/test_wal_api.rs index 38537000a..069b0b084 100644 --- a/tests/integration/functions/test_wal_api.rs +++ b/tests/integration/functions/test_wal_api.rs @@ -926,3 +926,53 @@ fn test_db_share_same_file() { ]] ); } + +#[test] +fn test_wal_api_simulate_spilled_frames() { + let (mut rng, _) = rng_from_time(); + let db1 = TempDatabase::new_empty(false); + let conn1 = db1.connect_limbo(); + let db2 = TempDatabase::new_empty(false); + let conn2 = db2.connect_limbo(); + conn1 + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY, y)") + .unwrap(); + conn2 + .execute("CREATE TABLE t(x INTEGER PRIMARY KEY, y)") + .unwrap(); + let watermark = conn1.wal_state().unwrap().max_frame; + for _ in 0..128 { + let key = rng.next_u32(); + let length = rng.next_u32() % 4096 + 1; + conn1 
+ .execute(format!( + "INSERT INTO t VALUES ({key}, randomblob({length}))" + )) + .unwrap(); + } + let mut frame = [0u8; 24 + 4096]; + conn2 + .checkpoint(CheckpointMode::Truncate { + upper_bound_inclusive: None, + }) + .unwrap(); + conn2.wal_insert_begin().unwrap(); + let frames_count = conn1.wal_state().unwrap().max_frame; + for frame_id in watermark + 1..=frames_count { + let mut info = conn1.wal_get_frame(frame_id, &mut frame).unwrap(); + info.db_size = 0; + info.put_to_frame_header(&mut frame); + conn2 + .wal_insert_frame(frame_id - watermark, &frame) + .unwrap(); + } + for _ in 0..128 { + let key = rng.next_u32(); + let length = rng.next_u32() % 4096 + 1; + conn2 + .execute(format!( + "INSERT INTO t VALUES ({key}, randomblob({length}))" + )) + .unwrap(); + } +} diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index e2cbfa06e..096327db1 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ b/tests/integration/query_processing/test_write_path.rs @@ -284,6 +284,7 @@ fn test_wal_checkpoint() -> anyhow::Result<()> { let conn = tmp_db.connect_limbo(); for i in 0..iterations { + log::info!("iteration #{}", i); let insert_query = format!("INSERT INTO test VALUES ({i})"); do_flush(&conn, &tmp_db)?; conn.checkpoint(CheckpointMode::Passive { @@ -823,7 +824,7 @@ pub fn run_query_core( on_row(row) } } - _ => unreachable!(), + r => panic!("unexpected step result: {r:?}"), } } }; From 74104036916563458e557247577a45fe48e6607a Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 25 Aug 2025 11:38:59 +0400 Subject: [PATCH 27/73] fix clippy --- tests/integration/query_processing/test_write_path.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/query_processing/test_write_path.rs b/tests/integration/query_processing/test_write_path.rs index 096327db1..60cd6495a 100644 --- a/tests/integration/query_processing/test_write_path.rs +++ 
b/tests/integration/query_processing/test_write_path.rs @@ -284,7 +284,7 @@ fn test_wal_checkpoint() -> anyhow::Result<()> { let conn = tmp_db.connect_limbo(); for i in 0..iterations { - log::info!("iteration #{}", i); + log::info!("iteration #{i}"); let insert_query = format!("INSERT INTO test VALUES ({i})"); do_flush(&conn, &tmp_db)?; conn.checkpoint(CheckpointMode::Passive { From 5d3780f25df53978f4d3289c4a02bc5ecd2f3cb3 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Mon, 25 Aug 2025 11:12:41 +0300 Subject: [PATCH 28/73] core/translate: Add `CREATE INDEX IF NOT EXISTS` support Fixes #2263 --- core/translate/index.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/translate/index.rs b/core/translate/index.rs index ef574f6af..7d332f0d2 100644 --- a/core/translate/index.rs +++ b/core/translate/index.rs @@ -44,6 +44,10 @@ pub fn translate_create_index( // Check if the index is being created on a valid btree table and // the name is globally unique in the schema. if !schema.is_unique_idx_name(&idx_name) { + // If IF NOT EXISTS is specified, silently return without error + if unique_if_not_exists.1 { + return Ok(program); + } crate::bail_parse_error!("Error: index with name '{idx_name}' already exists."); } let Some(tbl) = schema.tables.get(&tbl_name) else { From 5fe5e1548b2fe5b04117a20b31b7a9bc49f3ea7e Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Mon, 25 Aug 2025 11:21:46 +0300 Subject: [PATCH 29/73] core/io: Fix build on Android and iOS Commit ebe6aa0d28136753a03868c2990f714d1e1bbaf1 ("adjust cfg for unix and linux IO") adjusted the I/O conditional compilation, but forgot that Android and iOS are also part of Unix target family. Fixes #2500 --- core/io/mod.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/core/io/mod.rs b/core/io/mod.rs index f376b5a3c..1fded463e 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -506,15 +506,7 @@ cfg_block! 
{ pub use PlatformIO as SyscallIO; } - #[cfg(any(target_os = "android", target_os = "ios"))] { - mod unix; - #[cfg(feature = "fs")] - pub use unix::UnixIO; - pub use unix::UnixIO as SyscallIO; - pub use unix::UnixIO as PlatformIO; - } - - #[cfg(target_os = "windows")] { + #[cfg(target_os = "windows")] { mod windows; pub use windows::WindowsIO as PlatformIO; pub use PlatformIO as SyscallIO; From f7ad55b6809e9f04d2db1335cf95e536e970239a Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 25 Aug 2025 12:14:54 +0400 Subject: [PATCH 30/73] remove unnecessary argument --- core/lib.rs | 3 +-- core/mvcc/database/mod.rs | 1 - core/storage/btree.rs | 15 +++------------ core/storage/pager.rs | 4 ++-- core/vdbe/execute.rs | 2 +- core/vdbe/mod.rs | 12 +++--------- 6 files changed, 10 insertions(+), 27 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index e8e629887..7f35f8861 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1500,7 +1500,6 @@ impl Connection { pager.end_tx( true, // rollback = true for close self, - self.wal_auto_checkpoint_disabled.get(), ) })?; self.transaction_state.set(TransactionState::None); @@ -2106,7 +2105,7 @@ impl Statement { } let state = self.program.connection.transaction_state.get(); if let TransactionState::Write { .. 
} = state { - let end_tx_res = self.pager.end_tx(true, &self.program.connection, true)?; + let end_tx_res = self.pager.end_tx(true, &self.program.connection)?; self.program .connection .transaction_state diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index bb1a0608e..13e28c437 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -546,7 +546,6 @@ impl StateTransition for CommitStateMachine { .end_tx( false, // rollback = false since we're committing &self.connection, - self.connection.wal_auto_checkpoint_disabled.get(), ) .map_err(|e| LimboError::InternalError(e.to_string())) .unwrap(); diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 15e43d25c..67d1be044 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -7915,10 +7915,7 @@ mod tests { pager.deref(), ) .unwrap(); - pager - .io - .block(|| pager.end_tx(false, &conn, false)) - .unwrap(); + pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); pager.begin_read_tx().unwrap(); // FIXME: add sorted vector instead, should be okay for small amounts of keys for now :P, too lazy to fix right now let _c = cursor.move_to_root().unwrap(); @@ -8063,10 +8060,7 @@ mod tests { if let Some(c) = c { pager.io.wait_for_completion(c).unwrap(); } - pager - .io - .block(|| pager.end_tx(false, &conn, false)) - .unwrap(); + pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); } // Check that all keys can be found by seeking @@ -8272,10 +8266,7 @@ mod tests { if let Some(c) = c { pager.io.wait_for_completion(c).unwrap(); } - pager - .io - .block(|| pager.end_tx(false, &conn, false)) - .unwrap(); + pager.io.block(|| pager.end_tx(false, &conn)).unwrap(); } // Final validation diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 85f33775e..a49eb2ff6 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1026,7 +1026,6 @@ impl Pager { &self, rollback: bool, connection: &Connection, - wal_auto_checkpoint_disabled: bool, ) -> Result> { if 
connection.is_nested_stmt.get() { // Parent statement will handle the transaction rollback. @@ -1050,7 +1049,8 @@ impl Pager { self.rollback(schema_did_change, connection, is_write)?; return Ok(IOResult::Done(PagerCommitResult::Rollback)); } - let commit_status = return_if_io!(self.commit_dirty_pages(wal_auto_checkpoint_disabled)); + let commit_status = + return_if_io!(self.commit_dirty_pages(connection.wal_auto_checkpoint_disabled.get())); wal.borrow().end_write_tx(); wal.borrow().end_read_tx(); diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index c910d827f..849b37839 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -2170,7 +2170,7 @@ pub fn op_auto_commit( if *auto_commit != conn.auto_commit.get() { if *rollback { // TODO(pere): add rollback I/O logic once we implement rollback journal - return_if_io!(pager.end_tx(true, &conn, false)); + return_if_io!(pager.end_tx(true, &conn)); conn.transaction_state.replace(TransactionState::None); conn.auto_commit.replace(true); } else { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index a12e5772e..afb4485ac 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -434,9 +434,7 @@ impl Program { // Connection is closed for whatever reason, rollback the transaction. let state = self.connection.transaction_state.get(); if let TransactionState::Write { .. 
} = state { - pager - .io - .block(|| pager.end_tx(true, &self.connection, false))?; + pager.io.block(|| pager.end_tx(true, &self.connection))?; } return Err(LimboError::InternalError("Connection closed".to_string())); } @@ -608,11 +606,7 @@ impl Program { connection: &Connection, rollback: bool, ) -> Result> { - let cacheflush_status = pager.end_tx( - rollback, - connection, - connection.wal_auto_checkpoint_disabled.get(), - )?; + let cacheflush_status = pager.end_tx(rollback, connection)?; match cacheflush_status { IOResult::Done(_) => { if self.change_cnt_on { @@ -869,7 +863,7 @@ pub fn handle_program_error( _ => { pager .io - .block(|| pager.end_tx(true, connection, false)) + .block(|| pager.end_tx(true, connection)) .inspect_err(|e| { tracing::error!("end_tx failed: {e}"); })?; From cf7418663ccdb08e6a81b1da4568716742b32255 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Mon, 25 Aug 2025 16:40:09 +0530 Subject: [PATCH 31/73] update encryption tests to work with diff ciphers --- .../query_processing/encryption.rs | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/integration/query_processing/encryption.rs b/tests/integration/query_processing/encryption.rs index 3cca08587..7fdaec29d 100644 --- a/tests/integration/query_processing/encryption.rs +++ b/tests/integration/query_processing/encryption.rs @@ -50,6 +50,62 @@ fn test_per_page_encryption() -> anyhow::Result<()> { should_panic.is_err(), "should panic when accessing encrypted DB without key" ); + + // it should also panic if we specify either only key or cipher + let conn = tmp_db.connect_limbo(); + let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';").unwrap(); + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_: &Row| {}).unwrap(); + })); + assert!( + should_panic.is_err(), + "should panic when accessing encrypted DB without key" + ); + + let conn = tmp_db.connect_limbo(); + let should_panic 
= panic::catch_unwind(panic::AssertUnwindSafe(|| { + run_query( + &tmp_db, + &conn, + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';", + ).unwrap(); + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_: &Row| {}).unwrap(); + })); + assert!( + should_panic.is_err(), + "should panic when accessing encrypted DB without cipher name" + ); + + // it should panic if we specify wrong cipher or key + let conn = tmp_db.connect_limbo(); + let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + run_query( + &tmp_db, + &conn, + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76327';", + ).unwrap(); + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aes256gcm';").unwrap(); + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_: &Row| {}).unwrap(); + })); + assert!( + should_panic.is_err(), + "should panic when accessing encrypted DB with incorrect cipher" + ); + + let conn = tmp_db.connect_limbo(); + let should_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + run_query(&tmp_db, &conn, "PRAGMA cipher = 'aegis256';").unwrap(); + run_query( + &tmp_db, + &conn, + "PRAGMA hexkey = 'b1bbfda4f589dc9daaf004fe21111e00dc00c98237102f5c7002a5669fc76377';", + ).unwrap(); + run_query_on_row(&tmp_db, &conn, "SELECT * FROM test", |_: &Row| {}).unwrap(); + })); + assert!( + should_panic.is_err(), + "should panic when accessing encrypted DB with incorrect key" + ); } { From b85ba09014d871f4784e741b218f2d14acb7e288 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Mon, 25 Aug 2025 16:51:19 +0530 Subject: [PATCH 32/73] Fix clippy boss' complaints --- core/lib.rs | 4 ++-- core/storage/encryption.rs | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index 46a479a54..44a802db9 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1971,7 +1971,7 @@ impl Connection { } pub fn get_encryption_cipher_mode(&self) -> Option { - 
self.encryption_cipher.borrow().clone() + *self.encryption_cipher.borrow() } // if both key and cipher are set, set encryption context on pager @@ -1980,7 +1980,7 @@ impl Connection { let Some(key) = key_ref.as_ref() else { return; }; - let Some(cipher) = self.encryption_cipher.borrow().clone() else { + let Some(cipher) = *self.encryption_cipher.borrow() else { return; }; tracing::trace!("setting encryption ctx for connection"); diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index 17343f1c0..51cf84ee5 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -21,7 +21,7 @@ impl EncryptionKey { pub fn from_hex_string(s: &str) -> Result { let hex_str = s.trim(); let bytes = hex::decode(hex_str) - .map_err(|e| LimboError::InvalidArgument(format!("Invalid hex string: {}", e)))?; + .map_err(|e| LimboError::InvalidArgument(format!("Invalid hex string: {e}")))?; let key: [u8; 32] = bytes.try_into().map_err(|v: Vec| { LimboError::InvalidArgument(format!( "Hex string must decode to exactly 32 bytes, got {}", @@ -142,8 +142,7 @@ impl TryFrom<&str> for CipherMode { "aes256gcm" | "aes-256-gcm" | "aes_256_gcm" => Ok(CipherMode::Aes256Gcm), "aegis256" | "aegis-256" | "aegis_256" => Ok(CipherMode::Aegis256), _ => Err(LimboError::InvalidArgument(format!( - "Unknown cipher name: {}", - s + "Unknown cipher name: {s}" ))), } } From 16547cb569bcd8e743700cb716598e17fe86110a Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Mon, 25 Aug 2025 14:25:26 +0300 Subject: [PATCH 33/73] sqlite3: Implement sqlite3_next_stmt() --- sqlite3/include/sqlite3.h | 2 + sqlite3/src/lib.rs | 49 +++++++++++++++++++++- sqlite3/tests/compat/mod.rs | 78 +++++++++++++++++++++++++++++++++++ sqlite3/tests/sqlite3_tests.c | 2 + 4 files changed, 130 insertions(+), 1 deletion(-) diff --git a/sqlite3/include/sqlite3.h b/sqlite3/include/sqlite3.h index f56f38b8c..0d098ce81 100644 --- a/sqlite3/include/sqlite3.h +++ b/sqlite3/include/sqlite3.h @@ -107,6 +107,8 @@ int 
sqlite3_stmt_readonly(sqlite3_stmt *_stmt); int sqlite3_stmt_busy(sqlite3_stmt *_stmt); +sqlite3_stmt *sqlite3_next_stmt(sqlite3 *db, sqlite3_stmt *stmt); + int sqlite3_serialize(sqlite3 *_db, const char *_schema, void **_out, int *_out_bytes, unsigned int _flags); int sqlite3_deserialize(sqlite3 *_db, const char *_schema, const void *_in_, int _in_bytes, unsigned int _flags); diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 46d6d64b3..1d0fa1bed 100644 --- a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -57,6 +57,7 @@ struct sqlite3Inner { pub(crate) e_open_state: u8, pub(crate) p_err: *mut ffi::c_void, pub(crate) filename: CString, + pub(crate) stmt_list: *mut sqlite3_stmt, } impl sqlite3 { @@ -76,6 +77,7 @@ impl sqlite3 { e_open_state: SQLITE_STATE_OPEN, p_err: std::ptr::null_mut(), filename, + stmt_list: std::ptr::null_mut(), }; #[allow(clippy::arc_with_non_send_sync)] let inner = Arc::new(Mutex::new(inner)); @@ -91,6 +93,7 @@ pub struct sqlite3_stmt { Option, *mut ffi::c_void, )>, + pub(crate) next: *mut sqlite3_stmt, } impl sqlite3_stmt { @@ -99,6 +102,7 @@ impl sqlite3_stmt { db, stmt, destructors: Vec::new(), + next: std::ptr::null_mut(), } } } @@ -279,7 +283,12 @@ pub unsafe extern "C" fn sqlite3_prepare_v2( return SQLITE_ERROR; } }; - *out_stmt = Box::leak(Box::new(sqlite3_stmt::new(raw_db, stmt))); + let new_stmt = Box::leak(Box::new(sqlite3_stmt::new(raw_db, stmt))); + + new_stmt.next = db.stmt_list; + db.stmt_list = new_stmt; + + *out_stmt = new_stmt; SQLITE_OK } @@ -290,6 +299,25 @@ pub unsafe extern "C" fn sqlite3_finalize(stmt: *mut sqlite3_stmt) -> ffi::c_int } let stmt_ref = &mut *stmt; + if !stmt_ref.db.is_null() { + let db = &mut *stmt_ref.db; + let mut db_inner = db.inner.lock().unwrap(); + + if db_inner.stmt_list == stmt { + db_inner.stmt_list = stmt_ref.next; + } else { + let mut current = db_inner.stmt_list; + while !current.is_null() { + let current_ref = &mut *current; + if current_ref.next == stmt { + current_ref.next = 
stmt_ref.next; + break; + } + current = current_ref.next; + } + } + } + for (_idx, destructor_opt, ptr) in stmt_ref.destructors.drain(..) { if let Some(destructor_fn) = destructor_opt { destructor_fn(ptr); @@ -381,6 +409,25 @@ pub unsafe extern "C" fn sqlite3_stmt_busy(_stmt: *mut sqlite3_stmt) -> ffi::c_i stub!(); } +/// Iterate over all prepared statements in the database. +#[no_mangle] +pub unsafe extern "C" fn sqlite3_next_stmt( + db: *mut sqlite3, + stmt: *mut sqlite3_stmt, +) -> *mut sqlite3_stmt { + if db.is_null() { + return std::ptr::null_mut(); + } + if stmt.is_null() { + let db = &*db; + let db = db.inner.lock().unwrap(); + db.stmt_list + } else { + let stmt = &mut *stmt; + stmt.next + } +} + #[no_mangle] pub unsafe extern "C" fn sqlite3_serialize( _db: *mut sqlite3, diff --git a/sqlite3/tests/compat/mod.rs b/sqlite3/tests/compat/mod.rs index 94361cf11..52ed3d8fa 100644 --- a/sqlite3/tests/compat/mod.rs +++ b/sqlite3/tests/compat/mod.rs @@ -48,6 +48,7 @@ extern "C" { ) -> i32; fn libsql_wal_disable_checkpoint(db: *mut sqlite3) -> i32; fn sqlite3_column_int(stmt: *mut sqlite3_stmt, idx: i32) -> i64; + fn sqlite3_next_stmt(db: *mut sqlite3, stmt: *mut sqlite3_stmt) -> *mut sqlite3_stmt; fn sqlite3_bind_int(stmt: *mut sqlite3_stmt, idx: i32, val: i64) -> i32; fn sqlite3_bind_parameter_count(stmt: *mut sqlite3_stmt) -> i32; fn sqlite3_bind_parameter_name(stmt: *mut sqlite3_stmt, idx: i32) -> *const libc::c_char; @@ -1319,4 +1320,81 @@ mod tests { assert_eq!(sqlite3_close(db), SQLITE_OK); } } + + #[test] + fn test_sqlite3_next_stmt() { + const SQLITE_OK: i32 = 0; + + unsafe { + let mut db: *mut sqlite3 = ptr::null_mut(); + assert_eq!(sqlite3_open(c":memory:".as_ptr(), &mut db), SQLITE_OK); + + // Initially, there should be no prepared statements + let iter = sqlite3_next_stmt(db, ptr::null_mut()); + assert!(iter.is_null()); + + // Prepare first statement + let mut stmt1: *mut sqlite3_stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"SELECT 
1;".as_ptr(), -1, &mut stmt1, ptr::null_mut()), + SQLITE_OK + ); + assert!(!stmt1.is_null()); + + // Now there should be one statement + let iter = sqlite3_next_stmt(db, ptr::null_mut()); + assert_eq!(iter, stmt1); + + // And no more after that + let iter = sqlite3_next_stmt(db, stmt1); + assert!(iter.is_null()); + + // Prepare second statement + let mut stmt2: *mut sqlite3_stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"SELECT 2;".as_ptr(), -1, &mut stmt2, ptr::null_mut()), + SQLITE_OK + ); + assert!(!stmt2.is_null()); + + // Prepare third statement + let mut stmt3: *mut sqlite3_stmt = ptr::null_mut(); + assert_eq!( + sqlite3_prepare_v2(db, c"SELECT 3;".as_ptr(), -1, &mut stmt3, ptr::null_mut()), + SQLITE_OK + ); + assert!(!stmt3.is_null()); + + // Count all statements + let mut count = 0; + let mut iter = sqlite3_next_stmt(db, ptr::null_mut()); + while !iter.is_null() { + count += 1; + iter = sqlite3_next_stmt(db, iter); + } + assert_eq!(count, 3); + + // Finalize the middle statement + assert_eq!(sqlite3_finalize(stmt2), SQLITE_OK); + + // Count should now be 2 + count = 0; + iter = sqlite3_next_stmt(db, ptr::null_mut()); + while !iter.is_null() { + count += 1; + iter = sqlite3_next_stmt(db, iter); + } + assert_eq!(count, 2); + + // Finalize remaining statements + assert_eq!(sqlite3_finalize(stmt1), SQLITE_OK); + assert_eq!(sqlite3_finalize(stmt3), SQLITE_OK); + + // Should be no statements left + let iter = sqlite3_next_stmt(db, ptr::null_mut()); + assert!(iter.is_null()); + + assert_eq!(sqlite3_close(db), SQLITE_OK); + } + } } diff --git a/sqlite3/tests/sqlite3_tests.c b/sqlite3/tests/sqlite3_tests.c index 2cd490a93..9fc1ffc49 100644 --- a/sqlite3/tests/sqlite3_tests.c +++ b/sqlite3/tests/sqlite3_tests.c @@ -18,6 +18,7 @@ void test_sqlite3_bind_text2(); void test_sqlite3_bind_blob(); void test_sqlite3_column_type(); void test_sqlite3_column_decltype(); +void test_sqlite3_next_stmt(); int allocated = 0; @@ -35,6 +36,7 @@ int main(void) 
test_sqlite3_bind_blob(); test_sqlite3_column_type(); test_sqlite3_column_decltype(); + test_sqlite3_next_stmt(); return 0; } From 42c8a77bb7d3a35c7889f3ce0d6d7043c6458a08 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Mon, 25 Aug 2025 15:03:10 +0300 Subject: [PATCH 34/73] use existing payload_overflows() utility in local space calculation --- core/storage/btree.rs | 52 +++++++++++++------------------------------ 1 file changed, 16 insertions(+), 36 deletions(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 962e5d51a..b919e7c0e 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -6,9 +6,10 @@ use crate::{ storage::{ pager::{BtreePageAllocMode, Pager}, sqlite3_ondisk::{ - read_u32, read_varint, BTreeCell, DatabaseHeader, PageContent, PageSize, PageType, - TableInteriorCell, TableLeafCell, CELL_PTR_SIZE_BYTES, INTERIOR_PAGE_HEADER_SIZE_BYTES, - LEAF_PAGE_HEADER_SIZE_BYTES, LEFT_CHILD_PTR_SIZE_BYTES, + payload_overflows, read_u32, read_varint, BTreeCell, DatabaseHeader, PageContent, + PageSize, PageType, TableInteriorCell, TableLeafCell, CELL_PTR_SIZE_BYTES, + INTERIOR_PAGE_HEADER_SIZE_BYTES, LEAF_PAGE_HEADER_SIZE_BYTES, + LEFT_CHILD_PTR_SIZE_BYTES, }, state_machines::{ AdvanceState, CountState, EmptyTableState, MoveToRightState, MoveToState, RewindState, @@ -7040,48 +7041,27 @@ fn fill_cell_payload( write_varint_to_vec(record_buf.len() as u64, cell_payload); } - let payload_overflow_threshold_max = - payload_overflow_threshold_max(page_type, usable_space); - tracing::debug!( - "fill_cell_payload(record_size={}, payload_overflow_threshold_max={})", - record_buf.len(), - payload_overflow_threshold_max - ); - if record_buf.len() <= payload_overflow_threshold_max { + let max_local = payload_overflow_threshold_max(page_type, usable_space); + let min_local = payload_overflow_threshold_min(page_type, usable_space); + + let (overflows, local_size_if_overflow) = + payload_overflows(record_buf.len(), max_local, min_local, usable_space); + if 
!overflows { // enough allowed space to fit inside a btree page cell_payload.extend_from_slice(record_buf.as_ref()); break; } - let payload_overflow_threshold_min = - payload_overflow_threshold_min(page_type, usable_space); - // see e.g. https://github.com/sqlite/sqlite/blob/9591d3fe93936533c8c3b0dc4d025ac999539e11/src/dbstat.c#L371 - let mut space_left = payload_overflow_threshold_min - + (record_buf.len() - payload_overflow_threshold_min) % overflow_page_data_size; - - if space_left > payload_overflow_threshold_max { - space_left = payload_overflow_threshold_min; - } - - // cell_size must be equal to first value of space_left as this will be the bytes copied to non-overflow page. - let cell_size = space_left + cell_payload.len() + overflow_page_pointer_size; - - let prev_size = cell_payload.len(); - let new_data_size = prev_size + space_left; - cell_payload.resize(new_data_size + overflow_page_pointer_size, 0); - assert_eq!( - cell_size, - cell_payload.len(), - "cell_size={} != cell_payload.len()={}", - cell_size, - cell_payload.len() - ); + // so far we've written any of: left child page, rowid, payload size (depending on page type) + let cell_non_payload_elems_size = cell_payload.len(); + let new_total_local_size = cell_non_payload_elems_size + local_size_if_overflow; + cell_payload.resize(new_total_local_size, 0); *fill_cell_payload_state = FillCellPayloadState::CopyData { state: CopyDataState::Copy, - space_left_on_cur_page: space_left, + space_left_on_cur_page: local_size_if_overflow - overflow_page_pointer_size, // local_size_if_overflow includes the overflow page pointer, but we don't want to write payload data there. 
src_data_offset: 0, - dst_data_offset: prev_size, + dst_data_offset: cell_non_payload_elems_size, current_overflow_page: None, }; continue; From c6553d82b84ec7cd29a31329923a64a83e10f204 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Mon, 25 Aug 2025 15:05:04 +0300 Subject: [PATCH 35/73] Clarify expected behavior with assertion --- core/storage/btree.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index b919e7c0e..30530d691 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -7099,10 +7099,9 @@ fn fill_cell_payload( } if record_offset_slice.len() - amount_to_copy == 0 { - if let Some(cur_page) = current_overflow_page { - cur_page.unpin(); // We can safely unpin the current overflow page now. - } - // Everything copied. + let cur_page = current_overflow_page.as_ref().expect("we must have overflowed if the remaining payload fits on the current page"); + cur_page.unpin(); // We can safely unpin the current overflow page now. + // Everything copied. 
break; } *state = CopyDataState::AllocateOverflowPage; From 16b1ae4a9f40b7b650c35a260f5f2044043ff061 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Mon, 25 Aug 2025 15:12:37 +0300 Subject: [PATCH 36/73] Handle unpinning btree page in case of allocate overflow page error --- core/storage/btree.rs | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 30530d691..d3e9d50d3 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -7017,7 +7017,7 @@ fn fill_cell_payload( ) -> Result> { let overflow_page_pointer_size = 4; let overflow_page_data_size = usable_space - overflow_page_pointer_size; - loop { + let result = loop { let record_buf = record.get_payload(); match fill_cell_payload_state { FillCellPayloadState::Start => { @@ -7049,7 +7049,7 @@ fn fill_cell_payload( if !overflows { // enough allowed space to fit inside a btree page cell_payload.extend_from_slice(record_buf.as_ref()); - break; + break Ok(IOResult::Done(())); } // so far we've written any of: left child page, rowid, payload size (depending on page type) @@ -7102,13 +7102,22 @@ fn fill_cell_payload( let cur_page = current_overflow_page.as_ref().expect("we must have overflowed if the remaining payload fits on the current page"); cur_page.unpin(); // We can safely unpin the current overflow page now. // Everything copied. 
- break; + break Ok(IOResult::Done(())); } *state = CopyDataState::AllocateOverflowPage; *src_data_offset += amount_to_copy; } CopyDataState::AllocateOverflowPage => { - let new_overflow_page = return_if_io!(pager.allocate_overflow_page()); + let new_overflow_page = match pager.allocate_overflow_page() { + Ok(IOResult::Done(new_overflow_page)) => new_overflow_page, + Ok(IOResult::IO(io_result)) => return Ok(IOResult::IO(io_result)), + Err(e) => { + if let Some(cur_page) = current_overflow_page { + cur_page.unpin(); + } + break Err(e); + } + }; new_overflow_page.pin(); // Pin the current overflow page so the cache won't evict it because we need this page to be in memory for the next iteration of FillCellPayloadState::CopyData. if let Some(prev_page) = current_overflow_page { prev_page.unpin(); // We can safely unpin the previous overflow page now. @@ -7146,9 +7155,9 @@ fn fill_cell_payload( } } } - } + }; page.unpin(); - Ok(IOResult::Done(())) + result } /// Returns the maximum payload size (X) that can be stored directly on a b-tree page without spilling to overflow pages. /// From 40b7e3bf5a59206f7bb343f0a4729934c5353928 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Mon, 25 Aug 2025 19:16:15 +0530 Subject: [PATCH 37/73] rename `cipher` to `cipher_mode` for consistency --- core/lib.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index 44a802db9..51197ccae 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -456,7 +456,7 @@ impl Database { metrics: RefCell::new(ConnectionMetrics::new()), is_nested_stmt: Cell::new(false), encryption_key: RefCell::new(None), - encryption_cipher: RefCell::new(None), + encryption_cipher_mode: RefCell::new(None), }); let builtin_syms = self.builtin_syms.borrow(); // add built-in extensions symbols to the connection to prevent having to load each time @@ -888,7 +888,7 @@ pub struct Connection { /// Generally this is only true for ParseSchema. 
is_nested_stmt: Cell, encryption_key: RefCell>, - encryption_cipher: RefCell>, + encryption_cipher_mode: RefCell>, } impl Connection { @@ -1964,14 +1964,14 @@ impl Connection { self.set_encryption_context(); } - pub fn set_encryption_cipher(&self, cipher: CipherMode) { + pub fn set_encryption_cipher(&self, cipher_mode: CipherMode) { tracing::trace!("setting encryption cipher for connection"); - self.encryption_cipher.replace(Some(cipher)); + self.encryption_cipher_mode.replace(Some(cipher_mode)); self.set_encryption_context(); } pub fn get_encryption_cipher_mode(&self) -> Option { - *self.encryption_cipher.borrow() + *self.encryption_cipher_mode.borrow() } // if both key and cipher are set, set encryption context on pager @@ -1980,12 +1980,12 @@ impl Connection { let Some(key) = key_ref.as_ref() else { return; }; - let Some(cipher) = *self.encryption_cipher.borrow() else { + let Some(cipher_mode) = *self.encryption_cipher_mode.borrow() else { return; }; tracing::trace!("setting encryption ctx for connection"); let pager = self.pager.borrow(); - pager.set_encryption_context(cipher, key); + pager.set_encryption_context(cipher_mode, key); } } From 46e288ac26d13b52de37ebb8969c28a636497a39 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sat, 23 Aug 2025 16:37:36 -0400 Subject: [PATCH 38/73] Add append_frames_vectored to WAL api In addition to the existing `append_frame` which will write an individual frame to the WAL, we add a method `append_frames_vectored` that takes N frames and the db size which will need to be set for the last (commit) frame, and it calculates the checksums and submits them as a single `pwritev` call, reducing the number of syscalls needed for each write operation. 
--- core/storage/wal.rs | 115 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 1 deletion(-) diff --git a/core/storage/wal.rs b/core/storage/wal.rs index cc5c8e77f..ad78dec07 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -1,6 +1,7 @@ #![allow(clippy::not_unsafe_ptr_arg_deref)] use std::array; +use std::borrow::Cow; use std::cell::{RefCell, UnsafeCell}; use std::collections::{BTreeMap, HashMap, HashSet}; use strum::EnumString; @@ -275,6 +276,13 @@ pub trait Wal: Debug { db_size: u32, ) -> Result; + fn append_frames_vectored( + &mut self, + pages: Vec, + page_sz: PageSize, + db_size_on_commit: Option, + ) -> Result; + /// Complete append of frames by updating shared wal state. Before this /// all changes were stored locally. fn finish_append_frames_commit(&mut self) -> Result<()>; @@ -317,7 +325,8 @@ pub const CKPT_BATCH_PAGES: usize = 512; const MIN_AVG_RUN_FOR_FLUSH: f32 = 32.0; const MIN_BATCH_LEN_FOR_FLUSH: usize = 512; const MAX_INFLIGHT_WRITES: usize = 64; -const MAX_INFLIGHT_READS: usize = 512; +pub const MAX_INFLIGHT_READS: usize = 512; +pub const IOV_MAX: usize = 1024; type PageId = usize; struct InflightRead { @@ -1391,6 +1400,110 @@ impl Wal for WalFile { Ok(pages) } + /// Use pwritev to append many frames to the log at once + fn append_frames_vectored( + &mut self, + pages: Vec, + page_sz: PageSize, + db_size_on_commit: Option, + ) -> Result { + turso_assert!( + pages.len() <= IOV_MAX, + "we limit number of iovecs to IOV_MAX" + ); + self.ensure_header_if_needed(page_sz)?; + + let (header, shared_page_size, seq) = { + let shared = self.get_shared(); + let hdr_guard = shared.wal_header.lock(); + let header: WalHeader = *hdr_guard; + let shared_page_size = header.page_size; + let seq = header.checkpoint_seq; + (header, shared_page_size, seq) + }; + turso_assert!( + shared_page_size == page_sz.get(), + "page size mismatch, tried to change page size after WAL header was already initialized: 
shared.page_size={shared_page_size}, page_size={}", + page_sz.get() + ); + + // Prepare write buffers and bookkeeping + let mut iovecs: Vec> = Vec::with_capacity(pages.len()); + let mut page_frame_and_checksum: Vec<(PageRef, u64, (u32, u32))> = + Vec::with_capacity(pages.len()); + + // Rolling checksum input to each frame build + let mut rolling_csum: (u32, u32) = self.last_checksum; + + let mut next_frame_id = self.max_frame + 1; + // Build every frame in order, updating the rolling checksum + for (idx, page) in pages.iter().enumerate() { + let page_id = page.get().id as u64; + let plain = page.get_contents().as_ptr(); + + let data_to_write: std::borrow::Cow<[u8]> = { + let key = self.encryption_key.borrow(); + if let Some(k) = key.as_ref() { + Cow::Owned(encrypt_page(plain, page_id as usize, k)?) + } else { + Cow::Borrowed(plain) + } + }; + + let frame_db_size = if idx + 1 == pages.len() { + db_size_on_commit.unwrap_or(0) + } else { + 0 + }; + let (new_csum, frame_bytes) = prepare_wal_frame( + &self.buffer_pool, + &header, + rolling_csum, + shared_page_size, + page_id as u32, + frame_db_size, + &data_to_write, + ); + iovecs.push(frame_bytes); + + // (page, assigned_frame_id, cumulative_checksum_at_this_frame) + page_frame_and_checksum.push((page.clone(), next_frame_id, new_csum)); + + // Advance for the next frame + rolling_csum = new_csum; + next_frame_id += 1; + } + + let first_frame_id = self.max_frame + 1; + let start_off = self.frame_offset(first_frame_id); + + // pre-advance in-memory WAL state like the single-frame path + for (page, fid, csum) in &page_frame_and_checksum { + self.complete_append_frame(page.get().id as u64, *fid, *csum); + } + + // single completion for the whole batch + let total_len: i32 = iovecs.iter().map(|b| b.len() as i32).sum(); + let page_frame_for_cb = page_frame_and_checksum.clone(); + let c = Completion::new_write(move |res: Result| { + let Ok(bytes_written) = res else { + return; + }; + turso_assert!( + bytes_written == 
total_len, + "pwritev wrote {bytes_written} bytes, expected {total_len}" + ); + + for (page, fid, _csum) in &page_frame_for_cb { + page.clear_dirty(); + page.set_wal_tag(*fid, seq); + } + }); + + let c = self.get_shared().file.pwritev(start_off, iovecs, c)?; + Ok(c) + } + #[cfg(debug_assertions)] fn as_any(&self) -> &dyn std::any::Any { self From 02390887189a0e3685c3a947b88c2f6c4411c47c Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sat, 23 Aug 2025 16:52:26 -0400 Subject: [PATCH 39/73] Use new append_frames_vectored WAL method to flush pager cache and commit write tx --- core/storage/pager.rs | 147 +++++++++++++++++++++++++++--------------- 1 file changed, 95 insertions(+), 52 deletions(-) diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 17cfbc102..c3321b382 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1,4 +1,5 @@ use crate::result::LimboResult; +use crate::storage::wal::IOV_MAX; use crate::storage::{ btree::BTreePageInner, buffer_pool::BufferPool, @@ -346,8 +347,10 @@ pub enum BtreePageAllocMode { } /// This will keep track of the state of current cache commit in order to not repeat work +#[derive(Clone)] struct CommitInfo { state: Cell, + time: Cell, } /// Track the state of the auto-vacuum mode. 
@@ -563,6 +566,7 @@ impl Pager { } else { RefCell::new(AllocatePage1State::Done) }; + let now = io.now(); Ok(Self { db_file, wal, @@ -573,6 +577,7 @@ impl Pager { ))), commit_info: CommitInfo { state: CommitState::Start.into(), + time: now.into(), }, syncing: Rc::new(Cell::new(false)), checkpoint_state: RefCell::new(CheckpointState::Checkpoint), @@ -1250,36 +1255,46 @@ impl Pager { .iter() .copied() .collect::>(); - let mut completions: Vec = Vec::with_capacity(dirty_pages.len()); - for page_id in dirty_pages { + let len = dirty_pages.len().min(IOV_MAX); + let mut completions: Vec = Vec::new(); + let mut pages = Vec::with_capacity(len); + let page_sz = self.page_size.get().unwrap_or_default(); + let commit_frame = None; // cacheflush only so we are not setting a commit frame here + for (idx, page_id) in dirty_pages.iter().enumerate() { let page = { let mut cache = self.page_cache.write(); - let page_key = PageCacheKey::new(page_id); + let page_key = PageCacheKey::new(*page_id); let page = cache.get(&page_key).expect("we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it."); let page_type = page.get().contents.as_ref().unwrap().maybe_page_type(); - trace!( - "commit_dirty_pages(page={}, page_type={:?}", - page_id, - page_type - ); + trace!("cacheflush(page={}, page_type={:?}", page_id, page_type); page }; + pages.push(page); + if pages.len() == IOV_MAX { + let c = wal + .borrow_mut() + .append_frames_vectored( + std::mem::replace( + &mut pages, + Vec::with_capacity(std::cmp::min(IOV_MAX, dirty_pages.len() - idx)), + ), + page_sz, + commit_frame, + ) + .inspect_err(|_| { + for c in completions.iter() { + c.abort(); + } + })?; + completions.push(c); + } + } + if !pages.is_empty() { let c = wal .borrow_mut() - .append_frame( - page.clone(), - self.page_size.get().expect("page size not set"), - 0, - ) - .inspect_err(|_| { - for c in completions.iter() { - c.abort(); - } - })?; - // TODO: invalidade previous completions if this 
one fails + .append_frames_vectored(pages, page_sz, commit_frame)?; completions.push(c); } - // Pages are cleared dirty on callback completion Ok(completions) } @@ -1297,57 +1312,70 @@ impl Pager { "commit_dirty_pages() called on database without WAL".to_string(), )); }; + let mut checkpoint_result = CheckpointResult::default(); let res = loop { let state = self.commit_info.state.get(); trace!(?state); match state { CommitState::Start => { - let db_size = { + let now = self.io.now(); + self.commit_info.time.set(now); + let db_size_after = { self.io .block(|| self.with_header(|header| header.database_size))? .get() }; - let dirty_len = self.dirty_pages.borrow().iter().len(); - let mut completions: Vec = Vec::with_capacity(dirty_len); - for (curr_page_idx, page_id) in - self.dirty_pages.borrow().iter().copied().enumerate() - { - let is_last_frame = curr_page_idx == dirty_len - 1; - let db_size = if is_last_frame { db_size } else { 0 }; + let dirty_ids: Vec = self.dirty_pages.borrow().iter().copied().collect(); + if dirty_ids.is_empty() { + return Ok(IOResult::Done(PagerCommitResult::WalWritten)); + } + let page_sz = self.page_size.get().expect("page size not set"); + let mut completions: Vec = Vec::new(); + let mut pages: Vec = Vec::with_capacity(dirty_ids.len().min(IOV_MAX)); + let total = dirty_ids.len(); + + for (i, page_id) in dirty_ids.into_iter().enumerate() { let page = { let mut cache = self.page_cache.write(); let page_key = PageCacheKey::new(page_id); - let page = cache.get(&page_key).unwrap_or_else(|| { - panic!( - "we somehow added a page to dirty list but we didn't mark it as dirty, causing cache to drop it. 
page={page_id}" - ) - }); - let page_type = page.get().contents.as_ref().unwrap().maybe_page_type(); + let page = cache.get(&page_key).expect( + "dirty list contained a page that cache dropped (page={page_id})", + ); trace!( - "commit_dirty_pages(page={}, page_type={:?}", + "commit_dirty_pages(page={}, page_type={:?})", page_id, - page_type + page.get().contents.as_ref().unwrap().maybe_page_type() ); page }; + pages.push(page); - // TODO: invalidade previous completions on error here - let c = wal - .borrow_mut() - .append_frame( - page.clone(), - self.page_size.get().expect("page size not set"), - db_size, - ) - .inspect_err(|_| { - for c in completions.iter() { - c.abort(); + let end_of_chunk = pages.len() == IOV_MAX || i == total - 1; + if end_of_chunk { + let commit_flag = if i == total - 1 { + // Only the commit frame (final) frame carries the db_size + Some(db_size_after) + } else { + None + }; + let r = wal.borrow_mut().append_frames_vectored( + std::mem::take(&mut pages), + page_sz, + commit_flag, + ); + match r { + Ok(c) => completions.push(c), + Err(e) => { + for c in &completions { + c.abort(); + } + return Err(e); } - })?; - completions.push(c); + } + } } self.dirty_pages.borrow_mut().clear(); // Nothing to append @@ -1355,13 +1383,17 @@ impl Pager { return Ok(IOResult::Done(PagerCommitResult::WalWritten)); } else { self.commit_info.state.set(CommitState::SyncWal); + } + if !completions.iter().all(|c| c.is_completed()) { io_yield_many!(completions); } } CommitState::SyncWal => { self.commit_info.state.set(CommitState::AfterSyncWal); let c = wal.borrow_mut().sync()?; - io_yield_one!(c); + if !c.is_completed() { + io_yield_one!(c); + } } CommitState::AfterSyncWal => { turso_assert!(!wal.borrow().is_syncing(), "wal should have synced"); @@ -1378,7 +1410,9 @@ impl Pager { CommitState::SyncDbFile => { let c = sqlite3_ondisk::begin_sync(self.db_file.clone(), self.syncing.clone())?; self.commit_info.state.set(CommitState::AfterSyncDbFile); - io_yield_one!(c); 
+ if !c.is_completed() { + io_yield_one!(c); + } } CommitState::AfterSyncDbFile => { turso_assert!(!self.syncing.get(), "should have finished syncing"); @@ -1387,7 +1421,15 @@ impl Pager { } } }; - // We should only signal that we finished appenind frames after wal sync to avoid inconsistencies when sync fails + + let now = self.io.now(); + tracing::debug!( + "total time flushing cache: {} ms", + now.to_system_time() + .duration_since(self.commit_info.time.get().to_system_time()) + .unwrap() + .as_millis() + ); wal.borrow_mut().finish_append_frames_commit()?; Ok(IOResult::Done(res)) } @@ -2087,6 +2129,7 @@ impl Pager { self.checkpoint_state.replace(CheckpointState::Checkpoint); self.syncing.replace(false); self.commit_info.state.set(CommitState::Start); + self.commit_info.time.set(self.io.now()); self.allocate_page_state.replace(AllocatePageState::Start); self.free_page_state.replace(FreePageState::Start); #[cfg(not(feature = "omit_autovacuum"))] From daea841b47aa4793b4abcebc8ad5bca050b6435f Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sat, 23 Aug 2025 17:03:58 -0400 Subject: [PATCH 40/73] Minor adjustments/comments to wal append_frames_vectored method --- core/storage/pager.rs | 1 - core/storage/wal.rs | 14 ++++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/core/storage/pager.rs b/core/storage/pager.rs index c3321b382..6290129d6 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -347,7 +347,6 @@ pub enum BtreePageAllocMode { } /// This will keep track of the state of current cache commit in order to not repeat work -#[derive(Clone)] struct CommitInfo { state: Cell, time: Cell, diff --git a/core/storage/wal.rs b/core/storage/wal.rs index ad78dec07..8d28827cd 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -1433,7 +1433,7 @@ impl Wal for WalFile { Vec::with_capacity(pages.len()); // Rolling checksum input to each frame build - let mut rolling_csum: (u32, u32) = self.last_checksum; + let mut rolling_checksum: 
(u32, u32) = self.last_checksum; let mut next_frame_id = self.max_frame + 1; // Build every frame in order, updating the rolling checksum @@ -1451,14 +1451,16 @@ impl Wal for WalFile { }; let frame_db_size = if idx + 1 == pages.len() { + // if it's the final frame we are appending, and the caller included a db_size for the + // commit frame, then we ensure to set it in the header. db_size_on_commit.unwrap_or(0) } else { 0 }; - let (new_csum, frame_bytes) = prepare_wal_frame( + let (new_checksum, frame_bytes) = prepare_wal_frame( &self.buffer_pool, &header, - rolling_csum, + rolling_checksum, shared_page_size, page_id as u32, frame_db_size, @@ -1467,17 +1469,17 @@ impl Wal for WalFile { iovecs.push(frame_bytes); // (page, assigned_frame_id, cumulative_checksum_at_this_frame) - page_frame_and_checksum.push((page.clone(), next_frame_id, new_csum)); + page_frame_and_checksum.push((page.clone(), next_frame_id, new_checksum)); // Advance for the next frame - rolling_csum = new_csum; + rolling_checksum = new_checksum; next_frame_id += 1; } let first_frame_id = self.max_frame + 1; let start_off = self.frame_offset(first_frame_id); - // pre-advance in-memory WAL state like the single-frame path + // pre-advance in-memory WAL state for (page, fid, csum) in &page_frame_and_checksum { self.complete_append_frame(page.get().id as u64, *fid, *csum); } From 37a7ec74777ea1ac516465b883bac5156541cce9 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 25 Aug 2025 09:50:57 -0400 Subject: [PATCH 41/73] Update append_frames_vectored to use new encryption_ctx and apply review --- core/storage/pager.rs | 7 ++++++- core/storage/wal.rs | 6 +++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 6290129d6..1149acf61 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1291,7 +1291,12 @@ impl Pager { if !pages.is_empty() { let c = wal .borrow_mut() - .append_frames_vectored(pages, page_sz, commit_frame)?; + 
.append_frames_vectored(pages, page_sz, commit_frame) + .inspect_err(|_| { + for c in completions.iter() { + c.abort(); + } + })?; completions.push(c); } Ok(completions) diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 8d28827cd..15fd04afe 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -1442,9 +1442,9 @@ impl Wal for WalFile { let plain = page.get_contents().as_ptr(); let data_to_write: std::borrow::Cow<[u8]> = { - let key = self.encryption_key.borrow(); - if let Some(k) = key.as_ref() { - Cow::Owned(encrypt_page(plain, page_id as usize, k)?) + let ectx = self.encryption_ctx.borrow(); + if let Some(ctx) = ectx.as_ref() { + Cow::Owned(ctx.encrypt_page(plain, page_id as usize)?) } else { Cow::Borrowed(plain) } From 8cae10f744fa0362e2f127227b705bde8216b8e2 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Mon, 25 Aug 2025 16:20:32 +0300 Subject: [PATCH 42/73] Fix several issues with integrity_check Things that were just wrong: 1. No pages other than the root page were checked, because no looping was done. Add a loop. 2. Rightmost child page was never added to page stack. Add it. 
New integrity check features: - Add overflow pages to stack as well - Check that no page is referenced more than once in the tree --- core/storage/btree.rs | 349 ++++++++++++++++++++++++------------------ 1 file changed, 202 insertions(+), 147 deletions(-) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index cc14856a7..b2c36a560 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -37,7 +37,6 @@ use super::{ write_varint_to_vec, IndexInteriorCell, IndexLeafCell, OverflowCell, MINIMUM_CELL_SIZE, }, }; -#[cfg(debug_assertions)] use std::collections::HashSet; use std::{ cell::{Cell, Ref, RefCell}, @@ -5502,6 +5501,11 @@ pub enum IntegrityCheckError { got: usize, expected: usize, }, + #[error("Page {page_id} referenced multiple times")] + PageReferencedMultipleTimes { + page_id: usize, + is_overflow_page: bool, + }, } #[derive(Clone)] @@ -5514,6 +5518,7 @@ pub struct IntegrityCheckState { pub current_page: usize, page_stack: Vec, first_leaf_level: Option, + pages_referenced: HashSet, page: Option, } @@ -5526,6 +5531,7 @@ impl IntegrityCheckState { level: 0, max_intkey: i64::MAX, }], + pages_referenced: HashSet::new(), first_leaf_level: None, page: None, } @@ -5555,165 +5561,214 @@ pub fn integrity_check( errors: &mut Vec, pager: &Rc, ) -> Result> { - let Some(IntegrityCheckPageEntry { - page_idx, - level, - max_intkey, - }) = state.page_stack.last().cloned() - else { - return Ok(IOResult::Done(())); - }; - let page = match state.page.take() { - Some(page) => page, - None => { - let (page, c) = btree_read_page(pager, page_idx)?; - state.page = Some(page.get()); - if let Some(c) = c { - io_yield_one!(c); + loop { + let Some(IntegrityCheckPageEntry { + page_idx, + level, + max_intkey, + }) = state.page_stack.last().cloned() + else { + return Ok(IOResult::Done(())); + }; + let page = match state.page.take() { + Some(page) => page, + None => { + let (page, c) = btree_read_page(pager, page_idx)?; + state.page = Some(page.get()); + if let Some(c) = c 
{ + io_yield_one!(c); + } + state.page.take().expect("page should be present") } - page.get() - } - }; - turso_assert!(page.is_loaded(), "page should be loaded"); - state.page_stack.pop(); + }; + turso_assert!(page.is_loaded(), "page should be loaded"); + state.page_stack.pop(); - let contents = page.get_contents(); - let usable_space = pager.usable_space(); - let mut coverage_checker = CoverageChecker::new(page.get().id); + let contents = page.get_contents(); + let is_overflow_page = contents.maybe_page_type().is_none(); + if !state.pages_referenced.insert(page.get().id) { + errors.push(IntegrityCheckError::PageReferencedMultipleTimes { + page_id: page.get().id, + is_overflow_page, + }); + continue; + } + let usable_space = pager.usable_space(); + let mut coverage_checker = CoverageChecker::new(page.get().id); - // Now we check every cell for few things: - // 1. Check cell is in correct range. Not exceeds page and not starts before we have marked - // (cell content area). - // 2. We add the cell to coverage checker in order to check if cells do not overlap. - // 3. We check order of rowids in case of table pages. We iterate backwards in order to check - // if current cell's rowid is less than the next cell. We also check rowid is less than the - // parent's divider cell. In case of this page being root page max rowid will be i64::MAX. - // 4. We append pages to the stack to check later. - // 5. In case of leaf page, check if the current level(depth) is equal to other leaf pages we - // have seen. 
- let mut next_rowid = max_intkey; - for cell_idx in (0..contents.cell_count()).rev() { - let (cell_start, cell_length) = contents.cell_get_raw_region(cell_idx, usable_space); - if cell_start < contents.cell_content_area() as usize || cell_start > usable_space - 4 { - errors.push(IntegrityCheckError::CellOutOfRange { - cell_idx, - page_id: page.get().id, - cell_start, - cell_end: cell_start + cell_length, - content_area: contents.cell_content_area() as usize, - usable_space, - }); - } - if cell_start + cell_length > usable_space { - errors.push(IntegrityCheckError::CellOverflowsPage { - cell_idx, - page_id: page.get().id, - cell_start, - cell_end: cell_start + cell_length, - content_area: contents.cell_content_area() as usize, - usable_space, - }); - } - coverage_checker.add_cell(cell_start, cell_start + cell_length); - let cell = contents.cell_get(cell_idx, usable_space)?; - match cell { - BTreeCell::TableInteriorCell(table_interior_cell) => { + if is_overflow_page { + let next_overflow_page = contents.read_u32_no_offset(0); + if next_overflow_page != 0 { state.page_stack.push(IntegrityCheckPageEntry { - page_idx: table_interior_cell.left_child_page as usize, - level: level + 1, - max_intkey: table_interior_cell.rowid, - }); - let rowid = table_interior_cell.rowid; - if rowid > max_intkey || rowid > next_rowid { - errors.push(IntegrityCheckError::CellRowidOutOfRange { - page_id: page.get().id, - cell_idx, - rowid, - max_intkey, - next_rowid, - }); - } - next_rowid = rowid; - } - BTreeCell::TableLeafCell(table_leaf_cell) => { - // check depth of leaf pages are equal - if let Some(expected_leaf_level) = state.first_leaf_level { - if expected_leaf_level != level { - errors.push(IntegrityCheckError::LeafDepthMismatch { - page_id: page.get().id, - this_page_depth: level, - other_page_depth: expected_leaf_level, - }); - } - } else { - state.first_leaf_level = Some(level); - } - let rowid = table_leaf_cell.rowid; - if rowid > max_intkey || rowid > next_rowid { - 
errors.push(IntegrityCheckError::CellRowidOutOfRange { - page_id: page.get().id, - cell_idx, - rowid, - max_intkey, - next_rowid, - }); - } - next_rowid = rowid; - } - BTreeCell::IndexInteriorCell(index_interior_cell) => { - state.page_stack.push(IntegrityCheckPageEntry { - page_idx: index_interior_cell.left_child_page as usize, - level: level + 1, - max_intkey, // we don't care about intkey in non-table pages + page_idx: next_overflow_page as usize, + level, + max_intkey, }); } - BTreeCell::IndexLeafCell(_) => { - // check depth of leaf pages are equal - if let Some(expected_leaf_level) = state.first_leaf_level { - if expected_leaf_level != level { - errors.push(IntegrityCheckError::LeafDepthMismatch { - page_id: page.get().id, - this_page_depth: level, - other_page_depth: expected_leaf_level, - }); - } - } else { - state.first_leaf_level = Some(level); - } - } + continue; } - } - // Now we add free blocks to the coverage checker - let first_freeblock = contents.first_freeblock() as usize; - if first_freeblock > 0 { - let mut pc = first_freeblock; - while pc > 0 { - let next = contents.read_u16_no_offset(pc as usize) as usize; - let size = contents.read_u16_no_offset(pc as usize + 2) as usize; - // check it doesn't go out of range - if pc > usable_space - 4 { - errors.push(IntegrityCheckError::FreeBlockOutOfRange { + // Now we check every cell for few things: + // 1. Check cell is in correct range. Not exceeds page and not starts before we have marked + // (cell content area). + // 2. We add the cell to coverage checker in order to check if cells do not overlap. + // 3. We check order of rowids in case of table pages. We iterate backwards in order to check + // if current cell's rowid is less than the next cell. We also check rowid is less than the + // parent's divider cell. In case of this page being root page max rowid will be i64::MAX. + // 4. We append pages to the stack to check later. + // 5. 
In case of leaf page, check if the current level(depth) is equal to other leaf pages we + // have seen. + let mut next_rowid = max_intkey; + for cell_idx in (0..contents.cell_count()).rev() { + let (cell_start, cell_length) = contents.cell_get_raw_region(cell_idx, usable_space); + if cell_start < contents.cell_content_area() as usize || cell_start > usable_space - 4 { + errors.push(IntegrityCheckError::CellOutOfRange { + cell_idx, page_id: page.get().id, - start: pc, - end: pc + size, + cell_start, + cell_end: cell_start + cell_length, + content_area: contents.cell_content_area() as usize, + usable_space, }); - break; } - coverage_checker.add_free_block(pc, pc + size); - pc = next; + if cell_start + cell_length > usable_space { + errors.push(IntegrityCheckError::CellOverflowsPage { + cell_idx, + page_id: page.get().id, + cell_start, + cell_end: cell_start + cell_length, + content_area: contents.cell_content_area() as usize, + usable_space, + }); + } + coverage_checker.add_cell(cell_start, cell_start + cell_length); + let cell = contents.cell_get(cell_idx, usable_space)?; + match cell { + BTreeCell::TableInteriorCell(table_interior_cell) => { + state.page_stack.push(IntegrityCheckPageEntry { + page_idx: table_interior_cell.left_child_page as usize, + level: level + 1, + max_intkey: table_interior_cell.rowid, + }); + let rowid = table_interior_cell.rowid; + if rowid > max_intkey || rowid > next_rowid { + errors.push(IntegrityCheckError::CellRowidOutOfRange { + page_id: page.get().id, + cell_idx, + rowid, + max_intkey, + next_rowid, + }); + } + next_rowid = rowid; + } + BTreeCell::TableLeafCell(table_leaf_cell) => { + // check depth of leaf pages are equal + if let Some(expected_leaf_level) = state.first_leaf_level { + if expected_leaf_level != level { + errors.push(IntegrityCheckError::LeafDepthMismatch { + page_id: page.get().id, + this_page_depth: level, + other_page_depth: expected_leaf_level, + }); + } + } else { + state.first_leaf_level = Some(level); + } + let 
rowid = table_leaf_cell.rowid; + if rowid > max_intkey || rowid > next_rowid { + errors.push(IntegrityCheckError::CellRowidOutOfRange { + page_id: page.get().id, + cell_idx, + rowid, + max_intkey, + next_rowid, + }); + } + next_rowid = rowid; + if let Some(first_overflow_page) = table_leaf_cell.first_overflow_page { + state.page_stack.push(IntegrityCheckPageEntry { + page_idx: first_overflow_page as usize, + level, + max_intkey, + }); + } + } + BTreeCell::IndexInteriorCell(index_interior_cell) => { + state.page_stack.push(IntegrityCheckPageEntry { + page_idx: index_interior_cell.left_child_page as usize, + level: level + 1, + max_intkey, // we don't care about intkey in non-table pages + }); + if let Some(first_overflow_page) = index_interior_cell.first_overflow_page { + state.page_stack.push(IntegrityCheckPageEntry { + page_idx: first_overflow_page as usize, + level, + max_intkey, + }); + } + } + BTreeCell::IndexLeafCell(index_leaf_cell) => { + // check depth of leaf pages are equal + if let Some(expected_leaf_level) = state.first_leaf_level { + if expected_leaf_level != level { + errors.push(IntegrityCheckError::LeafDepthMismatch { + page_id: page.get().id, + this_page_depth: level, + other_page_depth: expected_leaf_level, + }); + } + } else { + state.first_leaf_level = Some(level); + } + if let Some(first_overflow_page) = index_leaf_cell.first_overflow_page { + state.page_stack.push(IntegrityCheckPageEntry { + page_idx: first_overflow_page as usize, + level, + max_intkey, + }); + } + } + } } + + if let Some(rightmost) = contents.rightmost_pointer() { + state.page_stack.push(IntegrityCheckPageEntry { + page_idx: rightmost as usize, + level: level + 1, + max_intkey, + }); + } + + // Now we add free blocks to the coverage checker + let first_freeblock = contents.first_freeblock() as usize; + if first_freeblock > 0 { + let mut pc = first_freeblock; + while pc > 0 { + let next = contents.read_u16_no_offset(pc as usize) as usize; + let size = 
contents.read_u16_no_offset(pc as usize + 2) as usize; + // check it doesn't go out of range + if pc > usable_space - 4 { + errors.push(IntegrityCheckError::FreeBlockOutOfRange { + page_id: page.get().id, + start: pc, + end: pc + size, + }); + break; + } + coverage_checker.add_free_block(pc, pc + size); + pc = next; + } + } + + // Let's check the overlap of freeblocks and cells now that we have collected them all. + coverage_checker.analyze( + usable_space, + contents.cell_content_area() as usize, + errors, + contents.num_frag_free_bytes() as usize, + ); } - - // Let's check the overlap of freeblocks and cells now that we have collected them all. - coverage_checker.analyze( - usable_space, - contents.cell_content_area() as usize, - errors, - contents.num_frag_free_bytes() as usize, - ); - - Ok(IOResult::Done(())) } pub fn btree_read_page( From 911b4c38a68e4b70c66ffbd41c829b2a76ad0ddd Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 21 Aug 2025 11:21:29 -0500 Subject: [PATCH 43/73] do not ignore silent failures from view creation We have an issue at the moment that when a materialized view fails to be created, we just swallow the error and leave the database in a funny state. We have can_create_view() to detect those issues early, but not all errors can be detected that early. 
--- core/util.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/core/util.rs b/core/util.rs index 200964a87..9b7d9e2c8 100644 --- a/core/util.rs +++ b/core/util.rs @@ -217,14 +217,13 @@ pub fn parse_schema_rows( if should_create_new { // Create a new IncrementalView - if let Ok(incremental_view) = - IncrementalView::from_sql(sql, schema) - { - let referenced_tables = - incremental_view.get_referenced_table_names(); - schema.add_materialized_view(incremental_view); - views_to_process.push((view_name, referenced_tables)); - } + // If this fails, we should propagate the error so the transaction rolls back + let incremental_view = + IncrementalView::from_sql(sql, schema)?; + let referenced_tables = + incremental_view.get_referenced_table_names(); + schema.add_materialized_view(incremental_view); + views_to_process.push((view_name, referenced_tables)); } } Stmt::CreateView { From 38def267048b1e44048e1aad502c478565d06235 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 21 Aug 2025 10:59:31 -0500 Subject: [PATCH 44/73] Add expr_compiler To be used in DBSP-based projections. This will compile an expression to VDBE bytecode and execute it. To do that we need to add a new type of Expression, which we call a Register. This is a way for us to pass parameters to a DBSP program which will be not columns or literals, but inputs from the DBSP deltas. 
--- core/incremental/expr_compiler.rs | 441 ++++++++++++++++++ core/incremental/mod.rs | 1 + core/translate/expr.rs | 15 +- core/translate/optimizer/mod.rs | 2 + vendored/sqlite3-parser/src/parser/ast/fmt.rs | 5 + vendored/sqlite3-parser/src/parser/ast/mod.rs | 3 + 6 files changed, 465 insertions(+), 2 deletions(-) create mode 100644 core/incremental/expr_compiler.rs diff --git a/core/incremental/expr_compiler.rs b/core/incremental/expr_compiler.rs new file mode 100644 index 000000000..310922633 --- /dev/null +++ b/core/incremental/expr_compiler.rs @@ -0,0 +1,441 @@ +// Expression compilation for incremental operators +// This module provides utilities to compile SQL expressions into VDBE subprograms +// that can be executed efficiently in the incremental computation context. + +use crate::schema::Schema; +use crate::storage::pager::Pager; +use crate::translate::emitter::Resolver; +use crate::translate::expr::translate_expr; +use crate::types::Text; +use crate::vdbe::builder::{ProgramBuilder, ProgramBuilderOpts}; +use crate::vdbe::insn::Insn; +use crate::vdbe::{Program, ProgramState, Register}; +use crate::SymbolTable; +use crate::{CaptureDataChangesMode, Connection, QueryMode, Result, Value}; +use std::rc::Rc; +use std::sync::Arc; +use turso_sqlite3_parser::ast::{Expr, Literal, Operator}; + +// Transform an expression to replace column references with Register expressions Why do we want to +// do this? +// +// Imagine you have a view like: +// +// create materialized view hex(count(*) + 2). translate_expr will usually try to find match names +// to either literals or columns. But "count(*)" is not a column in any sqlite table. +// +// We *could* theoretically have a table-representation of every DBSP-step, but it is a lot simpler +// to just pass registers as parameters to the VDBE expression, and teach translate_expr to +// recognize those. 
+// +// But because the expression compiler will not generate those register inputs, we have to +// transform the expression. +fn transform_expr_for_dbsp(expr: &Expr, input_column_names: &[String]) -> Expr { + match expr { + // Transform column references (represented as Id) to Register expressions + Expr::Id(name) => { + // Check if this is a column name from our input + if let Some(idx) = input_column_names + .iter() + .position(|col| col == name.as_str()) + { + // Replace with a Register expression + Expr::Register(idx) + } else { + // Not a column reference, keep as is + expr.clone() + } + } + // Recursively transform nested expressions + Expr::Binary(lhs, op, rhs) => Expr::Binary( + Box::new(transform_expr_for_dbsp(lhs, input_column_names)), + *op, + Box::new(transform_expr_for_dbsp(rhs, input_column_names)), + ), + Expr::Unary(op, operand) => Expr::Unary( + *op, + Box::new(transform_expr_for_dbsp(operand, input_column_names)), + ), + Expr::FunctionCall { + name, + distinctness, + args, + order_by, + filter_over, + } => Expr::FunctionCall { + name: name.clone(), + distinctness: *distinctness, + args: args.as_ref().map(|args_vec| { + args_vec + .iter() + .map(|arg| transform_expr_for_dbsp(arg, input_column_names)) + .collect() + }), + order_by: order_by.clone(), + filter_over: filter_over.clone(), + }, + Expr::Parenthesized(exprs) => Expr::Parenthesized( + exprs + .iter() + .map(|e| transform_expr_for_dbsp(e, input_column_names)) + .collect(), + ), + // For other expression types, keep as is + _ => expr.clone(), + } +} + +/// Enum to represent either a trivial or compiled expression +#[derive(Clone)] +pub enum ExpressionExecutor { + /// Trivial expression that can be evaluated inline + Trivial(TrivialExpression), + /// Compiled VDBE program for complex expressions + Compiled(Arc), +} + +/// Trivial expression that can be evaluated inline without VDBE +/// Only supports operations where operands have the same type (no coercion) +#[derive(Clone, Debug)] +pub enum 
TrivialExpression { + /// Direct column reference + Column(usize), + /// Immediate value + Immediate(Value), + /// Binary operation on trivial expressions (same-type operands only) + Binary { + left: Box, + op: Operator, + right: Box, + }, +} + +impl TrivialExpression { + /// Evaluate the trivial expression with the given input values + /// Panics if type mismatch occurs (this indicates a bug in validation) + pub fn evaluate(&self, values: &[Value]) -> Value { + match self { + TrivialExpression::Column(idx) => values.get(*idx).cloned().unwrap_or(Value::Null), + TrivialExpression::Immediate(val) => val.clone(), + TrivialExpression::Binary { left, op, right } => { + let left_val = left.evaluate(values); + let right_val = right.evaluate(values); + + // Only perform operations on same-type operands + match op { + Operator::Add => match (&left_val, &right_val) { + (Value::Integer(a), Value::Integer(b)) => Value::Integer(a + b), + (Value::Float(a), Value::Float(b)) => Value::Float(a + b), + (Value::Null, _) | (_, Value::Null) => Value::Null, + _ => panic!("Type mismatch in trivial expression: {left_val:?} + {right_val:?}. This is a bug in trivial expression validation."), + }, + Operator::Subtract => match (&left_val, &right_val) { + (Value::Integer(a), Value::Integer(b)) => Value::Integer(a - b), + (Value::Float(a), Value::Float(b)) => Value::Float(a - b), + (Value::Null, _) | (_, Value::Null) => Value::Null, + _ => panic!("Type mismatch in trivial expression: {left_val:?} - {right_val:?}. This is a bug in trivial expression validation."), + }, + Operator::Multiply => match (&left_val, &right_val) { + (Value::Integer(a), Value::Integer(b)) => Value::Integer(a * b), + (Value::Float(a), Value::Float(b)) => Value::Float(a * b), + (Value::Null, _) | (_, Value::Null) => Value::Null, + _ => panic!("Type mismatch in trivial expression: {left_val:?} * {right_val:?}. 
This is a bug in trivial expression validation."), + }, + Operator::Divide => match (&left_val, &right_val) { + (Value::Integer(a), Value::Integer(b)) => { + if *b != 0 { + Value::Integer(a / b) + } else { + Value::Null + } + } + (Value::Float(a), Value::Float(b)) => { + if *b != 0.0 { + Value::Float(a / b) + } else { + Value::Null + } + } + (Value::Null, _) | (_, Value::Null) => Value::Null, + _ => panic!("Type mismatch in trivial expression: {left_val:?} / {right_val:?}. This is a bug in trivial expression validation."), + }, + _ => panic!("Unsupported operator in trivial expression: {op:?}"), + } + } + } + } +} + +/// Compiled expression that can be executed on row values +#[derive(Clone)] +pub struct CompiledExpression { + /// The expression executor (trivial or compiled) + pub executor: ExpressionExecutor, + /// Number of input values expected (columns from the row) + pub input_count: usize, +} + +impl std::fmt::Debug for CompiledExpression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut s = f.debug_struct("CompiledExpression"); + s.field("input_count", &self.input_count); + match &self.executor { + ExpressionExecutor::Trivial(t) => s.field("executor", &format!("Trivial({t:?})")), + ExpressionExecutor::Compiled(p) => { + s.field("executor", &format!("Compiled({} insns)", p.insns.len())) + } + }; + s.finish() + } +} + +impl CompiledExpression { + /// Get the "type" of a trivial expression for type checking + /// Returns None if type can't be determined statically + fn get_trivial_type(expr: &TrivialExpression) -> Option<&'static str> { + match expr { + TrivialExpression::Column(_) => None, // Can't know column type statically + TrivialExpression::Immediate(val) => match val { + Value::Integer(_) => Some("integer"), + Value::Float(_) => Some("float"), + Value::Text(_) => Some("text"), + Value::Null => Some("null"), + _ => None, + }, + TrivialExpression::Binary { left, right, .. 
} => { + // For binary ops, both sides must have the same type + let left_type = Self::get_trivial_type(left)?; + let right_type = Self::get_trivial_type(right)?; + if left_type == right_type { + Some(left_type) + } else { + None // Type mismatch + } + } + } + } + + /// Check if an expression is trivial (columns, immediates, and simple arithmetic) + /// Only considers expressions trivial if they don't require type coercion + fn is_trivial_expr(expr: &Expr, input_column_names: &[String]) -> Option { + match expr { + // Column reference or register + Expr::Id(name) => input_column_names + .iter() + .position(|col| col == name.as_str()) + .map(TrivialExpression::Column), + Expr::Register(idx) => Some(TrivialExpression::Column(*idx)), + + // Immediate values + Expr::Literal(lit) => { + let value = match lit { + Literal::Numeric(n) => { + if let Ok(i) = n.parse::() { + Value::Integer(i) + } else if let Ok(f) = n.parse::() { + Value::Float(f) + } else { + return None; + } + } + Literal::String(s) => { + let cleaned = s.trim_matches('\'').trim_matches('"'); + Value::Text(Text::new(cleaned)) + } + Literal::Null => Value::Null, + _ => return None, + }; + Some(TrivialExpression::Immediate(value)) + } + + // Binary operations with simple operators + Expr::Binary(left, op, right) => { + // Only support simple arithmetic operators + match op { + Operator::Add | Operator::Subtract | Operator::Multiply | Operator::Divide => { + // Both operands must be trivial + let left_trivial = Self::is_trivial_expr(left, input_column_names)?; + let right_trivial = Self::is_trivial_expr(right, input_column_names)?; + + // Check if we can determine types statically + // If both are immediates, they must have the same type + // If either is a column, we can't validate at compile time, + // but we'll assert at runtime if there's a mismatch + if let (Some(left_type), Some(right_type)) = ( + Self::get_trivial_type(&left_trivial), + Self::get_trivial_type(&right_trivial), + ) { + // Both types are 
known - they must match (or one is null) + if left_type != right_type + && left_type != "null" + && right_type != "null" + { + return None; // Type mismatch - not trivial + } + } + // If we can't determine types (columns involved), we optimistically + // assume they'll match at runtime (and assert if they don't) + + Some(TrivialExpression::Binary { + left: Box::new(left_trivial), + op: *op, + right: Box::new(right_trivial), + }) + } + _ => None, + } + } + + // Parenthesized expressions with single element + Expr::Parenthesized(exprs) if exprs.len() == 1 => { + Self::is_trivial_expr(&exprs[0], input_column_names) + } + + _ => None, + } + } + + /// Compile a SQL expression into either a trivial executor or VDBE program + /// + /// For trivial expressions (columns, immediates, simple same-type arithmetic), uses inline evaluation. + /// For complex expressions or those requiring type coercion, compiles to VDBE bytecode. + pub fn compile( + expr: &Expr, + input_column_names: &[String], + schema: &Schema, + syms: &SymbolTable, + connection: Arc, + ) -> Result { + let input_count = input_column_names.len(); + + // First, check if this is a trivial expression + if let Some(trivial) = Self::is_trivial_expr(expr, input_column_names) { + return Ok(CompiledExpression { + executor: ExpressionExecutor::Trivial(trivial), + input_count, + }); + } + + // Fall back to VDBE compilation for complex expressions + // Create a minimal program builder for expression compilation + let mut builder = ProgramBuilder::new( + QueryMode::Normal, + CaptureDataChangesMode::Off, + ProgramBuilderOpts { + num_cursors: 0, + approx_num_insns: 5, // Most expressions are simple + approx_num_labels: 0, // Expressions don't need labels + }, + ); + + // Allocate registers for input values + let input_count = input_column_names.len(); + + // Allocate input registers + for _ in 0..input_count { + builder.alloc_register(); + } + + // Allocate a temp register for computation + let temp_result_register = 
builder.alloc_register(); + + // Transform the expression to replace column references with Register expressions + let transformed_expr = transform_expr_for_dbsp(expr, input_column_names); + + // Create a resolver for translate_expr + let resolver = Resolver::new(schema, syms); + + // Translate the transformed expression to bytecode + translate_expr( + &mut builder, + None, // No table references needed for pure expressions + &transformed_expr, + temp_result_register, + &resolver, + )?; + + // Copy the result to register 0 for return + builder.emit_insn(Insn::Copy { + src_reg: temp_result_register, + dst_reg: 0, + extra_amount: 0, + }); + + // Add a Halt instruction to complete the subprogram + builder.emit_insn(Insn::Halt { + err_code: 0, + description: String::new(), + }); + + // Build the program from the compiled expression bytecode + let program = Arc::new(builder.build(connection, false, "")); + + Ok(CompiledExpression { + executor: ExpressionExecutor::Compiled(program), + input_count, + }) + } + + /// Execute the compiled expression with the given input values + pub fn execute(&self, values: &[Value], pager: Rc) -> Result { + match &self.executor { + ExpressionExecutor::Trivial(trivial) => { + // Fast path: evaluate trivial expression inline + Ok(trivial.evaluate(values)) + } + ExpressionExecutor::Compiled(program) => { + // Slow path: execute VDBE program + // Create a state with the input values loaded into registers + let mut state = ProgramState::new(program.max_registers, 0); + + // Load input values into registers + for (idx, value) in values.iter().take(self.input_count).enumerate() { + state.set_register(idx, Register::Value(value.clone())); + } + + // Execute the program + let mut pc = 0usize; + while pc < program.insns.len() { + let (insn, insn_fn) = &program.insns[pc]; + state.pc = pc as u32; + + // Execute the instruction + match insn_fn(program, &mut state, insn, &pager, None)? 
{ + crate::vdbe::execute::InsnFunctionStepResult::IO(_) => { + return Err(crate::LimboError::InternalError( + "Expression evaluation encountered unexpected I/O".to_string(), + )); + } + crate::vdbe::execute::InsnFunctionStepResult::Done => { + break; + } + crate::vdbe::execute::InsnFunctionStepResult::Row => { + return Err(crate::LimboError::InternalError( + "Expression evaluation produced unexpected row".to_string(), + )); + } + crate::vdbe::execute::InsnFunctionStepResult::Interrupt => { + return Err(crate::LimboError::InternalError( + "Expression evaluation was interrupted".to_string(), + )); + } + crate::vdbe::execute::InsnFunctionStepResult::Busy => { + return Err(crate::LimboError::InternalError( + "Expression evaluation encountered busy state".to_string(), + )); + } + crate::vdbe::execute::InsnFunctionStepResult::Step => { + pc = state.pc as usize; + } + } + } + + // The compiled expression puts the result in register 0 + match state.get_register(0) { + Register::Value(v) => Ok(v.clone()), + _ => Ok(Value::Null), + } + } + } + } +} diff --git a/core/incremental/mod.rs b/core/incremental/mod.rs index d80a09081..328f1a510 100644 --- a/core/incremental/mod.rs +++ b/core/incremental/mod.rs @@ -1,4 +1,5 @@ pub mod dbsp; +pub mod expr_compiler; pub mod hashable_row; pub mod operator; pub mod view; diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 53111904a..398b31c3c 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -2206,6 +2206,15 @@ pub fn translate_expr( }); Ok(target_register) } + ast::Expr::Register(src_reg) => { + // For DBSP expression compilation: copy from source register to target + program.emit_insn(Insn::Copy { + src_reg: *src_reg, + dst_reg: target_register, + extra_amount: 0, + }); + Ok(target_register) + } }?; if let Some(span) = constant_span { @@ -2828,7 +2837,8 @@ where | ast::Expr::DoublyQualified(..) | ast::Expr::Name(_) | ast::Expr::Qualified(..) 
- | ast::Expr::Variable(_) => { + | ast::Expr::Variable(_) + | ast::Expr::Register(_) => { // No nested expressions } } @@ -3004,7 +3014,8 @@ where | ast::Expr::DoublyQualified(..) | ast::Expr::Name(_) | ast::Expr::Qualified(..) - | ast::Expr::Variable(_) => { + | ast::Expr::Variable(_) + | ast::Expr::Register(_) => { // No nested expressions } } diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index e34b46e91..8502ca005 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -673,6 +673,7 @@ impl Optimizable for ast::Expr { Expr::Subquery(..) => false, Expr::Unary(_, expr) => expr.is_nonnull(tables), Expr::Variable(..) => false, + Expr::Register(..) => false, // Register values can be null } } /// Returns true if the expression is a constant i.e. does not depend on variables or columns etc. @@ -750,6 +751,7 @@ impl Optimizable for ast::Expr { Expr::Subquery(_) => false, Expr::Unary(_, expr) => expr.is_constant(resolver), Expr::Variable(_) => false, + Expr::Register(_) => false, // Register values are not constants } } /// Returns true if the expression is a constant expression that, when evaluated as a condition, is always true or false diff --git a/vendored/sqlite3-parser/src/parser/ast/fmt.rs b/vendored/sqlite3-parser/src/parser/ast/fmt.rs index 64f722421..96013bf35 100644 --- a/vendored/sqlite3-parser/src/parser/ast/fmt.rs +++ b/vendored/sqlite3-parser/src/parser/ast/fmt.rs @@ -890,6 +890,11 @@ impl ToTokens for Expr { Some(_) => s.append(TK_VARIABLE, Some(&("?".to_owned() + var))), None => s.append(TK_VARIABLE, Some("?")), }, + Self::Register(reg) => { + // This is for internal use only, not part of SQL syntax + // Use a special notation that won't conflict with SQL + s.append(TK_VARIABLE, Some(&format!("$r{reg}"))) + } } } } diff --git a/vendored/sqlite3-parser/src/parser/ast/mod.rs b/vendored/sqlite3-parser/src/parser/ast/mod.rs index de096827f..b67c51507 100644 --- 
a/vendored/sqlite3-parser/src/parser/ast/mod.rs +++ b/vendored/sqlite3-parser/src/parser/ast/mod.rs @@ -365,6 +365,9 @@ pub enum Expr { }, /// binary expression Binary(Box, Operator, Box), + /// Register reference for DBSP expression compilation + /// This is not part of SQL syntax but used internally for incremental computation + Register(usize), /// `CASE` expression Case { /// operand From 097510216e7acd46f7041a394f765141f5241a39 Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 21 Aug 2025 11:09:17 -0500 Subject: [PATCH 45/73] implement the projector operator for DBSP My goal with this patch is to be able to implement the ProjectOperator for DBSP circuits using VDBE for expression evaluation. *not* doing so is dangerous for the following reason: we will end up with different, subtle, and incompatible behavior between SQLite expressions if they are used in views versus outside of views. In fact, even in our prototype had them: our projection tests, which used to pass, were actually wrong =) (sqlite would return something different if those functions were executed outside the view context) For optimization reasons, we single out trivial expressions: they don't have go through VDBE. Trivial expressions are expressions that only involve Columns, Literals, and simple operators on elements of the same type. Even type coercion takes this out of the realm of trivial. Everything that is not trivial, is then translated with translate_expr - in the same way SQLite will, and then compiled with VDBE. We can, over time, make this process much better. There are essentially infinite opportunities for optimization here. But for now, the main warts are: * VDBE execution needs a connection * There is no good way in VDBE to pass parameters to a program. * It is almost trivial to pollute the original connection. For example, we need to issue HALT for the program to stop, but seeing that halt will usually cause the program to try and halt the original program. 
Subprograms, like the ones we use in triggers are a possible solution, but they are much more expensive to execute, especially given that our execution would essentially have to have a program with no other role than to wrap the subprogram. Therefore, what I am doing is: * There is an in-memory database inside the projection operator (an obvious optimization is to share it with *all* projection operators). * We obtain a connection to that database when the operator is created * We use that connection to execute our VDBE, which offers a clean, safe and isolated way to execute the expression. * We feed the values to the program manually by editing the registers directly. --- core/incremental/operator.rs | 224 ++++++++++++++++++++++++++------ core/incremental/view.rs | 188 +++++---------------------- core/translate/view.rs | 2 +- core/vdbe/mod.rs | 8 ++ testing/materialized_views.test | 19 ++- 5 files changed, 245 insertions(+), 196 deletions(-) diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 740044988..9b0a9d15a 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -2,13 +2,15 @@ // Operator DAG for DBSP-style incremental computation // Based on Feldera DBSP design but adapted for Turso's architecture +use crate::incremental::expr_compiler::CompiledExpression; use crate::incremental::hashable_row::HashableRow; use crate::types::Text; -use crate::Value; +use crate::{Connection, Database, SymbolTable, Value}; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display}; use std::sync::Arc; use std::sync::Mutex; +use turso_sqlite3_parser::ast::*; /// Tracks computation counts to verify incremental behavior (for tests now), and in the future /// should be used to provide statistics. 
@@ -342,14 +344,13 @@ impl FilterPredicate { } #[derive(Debug, Clone)] -pub enum ProjectColumn { - /// Direct column reference - Column(String), - /// Computed expression - Expression { - expr: Box, - alias: Option, - }, +pub struct ProjectColumn { + /// The original SQL expression (for debugging/fallback) + pub expr: turso_sqlite3_parser::ast::Expr, + /// Optional alias for the column + pub alias: Option, + /// Compiled expression (handles both trivial columns and complex expressions) + pub compiled: CompiledExpression, } #[derive(Debug, Clone)] @@ -584,34 +585,189 @@ impl IncrementalOperator for FilterOperator { } /// Project operator - selects/transforms columns -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ProjectOperator { columns: Vec, input_column_names: Vec, output_column_names: Vec, current_state: Delta, tracker: Option>>, + // Internal in-memory connection for expression evaluation + // Programs are very dependent on having a connection, so give it one. + // + // We could in theory pass the current connection, but there are a host of problems with that. + // For example: during a write transaction, where views are usually updated, we have autocommit + // on. When the program we are executing calls Halt, it will try to commit the current + // transaction, which is absolutely incorrect. + // + // There are other ways to solve this, but a read-only connection to an empty in-memory + // database gives us the closest environment we need to execute expressions. 
+ internal_conn: Arc, +} + +impl std::fmt::Debug for ProjectOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ProjectOperator") + .field("columns", &self.columns) + .field("input_column_names", &self.input_column_names) + .field("output_column_names", &self.output_column_names) + .field("current_state", &self.current_state) + .field("tracker", &self.tracker) + .finish_non_exhaustive() + } } impl ProjectOperator { - pub fn new(columns: Vec, input_column_names: Vec) -> Self { + /// Create a new ProjectOperator from a SELECT statement, extracting projection columns + pub fn from_select( + select: &turso_sqlite3_parser::ast::Select, + input_column_names: Vec, + schema: &crate::schema::Schema, + ) -> crate::Result { + use turso_sqlite3_parser::ast::*; + + // Set up internal connection for expression evaluation + let io = Arc::new(crate::MemoryIO::new()); + let db = Database::open_file( + io, ":memory:", false, // no MVCC needed for expression evaluation + false, // no indexes needed + )?; + let internal_conn = db.connect()?; + // Set to read-only mode and disable auto-commit since we're only evaluating expressions + internal_conn.query_only.set(true); + internal_conn.auto_commit.set(false); + + let temp_syms = SymbolTable::new(); + + // Extract columns from SELECT statement + let columns = if let OneSelect::Select(ref select_stmt) = &*select.body.select { + let mut columns = Vec::new(); + for result_col in &select_stmt.columns { + match result_col { + ResultColumn::Expr(expr, alias) => { + let alias_str = if let Some(As::As(alias_name)) = alias { + Some(alias_name.as_str().to_string()) + } else { + None + }; + // Try to compile the expression (handles both columns and complex expressions) + match CompiledExpression::compile( + expr, + &input_column_names, + schema, + &temp_syms, + internal_conn.clone(), + ) { + Ok(compiled) => { + columns.push(ProjectColumn { + expr: expr.clone(), + alias: alias_str, + compiled, + }); + 
} + Err(_) => { + // If compilation fails, skip this column for now + // In the future we might want to handle this better + } + } + } + ResultColumn::Star => { + // Select all columns - create trivial column references + for name in &input_column_names { + // Create an Id expression for the column + let expr = Expr::Id(Name::Ident(name.clone())); + // This should always compile successfully as a trivial column reference + if let Ok(compiled) = CompiledExpression::compile( + &expr, + &input_column_names, + schema, + &temp_syms, + internal_conn.clone(), + ) { + columns.push(ProjectColumn { + expr, + alias: None, + compiled, + }); + } + } + } + _ => { + // For now, skip TableStar and other cases + } + } + } + + if columns.is_empty() { + // If no columns were extracted, default to projecting all input columns + input_column_names + .iter() + .filter_map(|name| { + let expr = Expr::Id(Name::Ident(name.clone())); + CompiledExpression::compile( + &expr, + &input_column_names, + schema, + &temp_syms, + internal_conn.clone(), + ) + .ok() + .map(|compiled| ProjectColumn { + expr, + alias: None, + compiled, + }) + }) + .collect() + } else { + columns + } + } else { + // Not a simple SELECT statement, default to projecting all columns + input_column_names + .iter() + .filter_map(|name| { + let expr = Expr::Id(Name::Ident(name.clone())); + CompiledExpression::compile( + &expr, + &input_column_names, + schema, + &temp_syms, + internal_conn.clone(), + ) + .ok() + .map(|compiled| ProjectColumn { + expr, + alias: None, + compiled, + }) + }) + .collect() + }; + + // Generate output column names based on aliases or expressions let output_column_names = columns .iter() - .map(|c| match c { - ProjectColumn::Column(name) => name.clone(), - ProjectColumn::Expression { alias, .. 
} => { - alias.clone().unwrap_or_else(|| "expr".to_string()) - } + .map(|c| { + c.alias.clone().unwrap_or_else(|| { + // For simple column references, use the column name + if let Expr::Id(name) = &c.expr { + name.as_str().to_string() + } else { + "expr".to_string() + } + }) }) .collect(); - Self { + Ok(Self { columns, input_column_names, output_column_names, current_state: Delta::new(), tracker: None, - } + internal_conn, + }) } /// Get the columns for this projection @@ -623,24 +779,19 @@ impl ProjectOperator { let mut output = Vec::new(); for col in &self.columns { - match col { - ProjectColumn::Column(name) => { - if let Some(idx) = self.input_column_names.iter().position(|c| c == name) { - if let Some(v) = values.get(idx) { - output.push(v.clone()); - } else { - output.push(Value::Null); - } - } else { - output.push(Value::Null); - } - } - ProjectColumn::Expression { expr, .. } => { - // Evaluate the expression - let result = self.evaluate_expression(expr, values); - output.push(result); - } - } + // Use the internal connection's pager for expression evaluation + let internal_pager = self.internal_conn.pager.borrow().clone(); + + // Execute the compiled expression (handles both columns and complex expressions) + let result = col + .compiled + .execute(values, internal_pager) + .unwrap_or_else(|_| { + // Fall back to manual evaluation on error + // This can happen for expressions with unsupported operations + self.evaluate_expression(&col.expr, values) + }); + output.push(result); } output @@ -648,7 +799,6 @@ impl ProjectOperator { fn evaluate_expression(&self, expr: &turso_parser::ast::Expr, values: &[Value]) -> Value { use turso_parser::ast::*; - match expr { Expr::Id(name) => { if let Some(idx) = self diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 3db09dea2..360eca7ce 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -1,7 +1,7 @@ use super::dbsp::{RowKeyStream, RowKeyZSet}; use super::operator::{ 
AggregateFunction, AggregateOperator, ComputationTracker, Delta, FilterOperator, - FilterPredicate, IncrementalOperator, ProjectColumn, ProjectOperator, + FilterPredicate, IncrementalOperator, ProjectOperator, }; use crate::schema::{BTreeTable, Column, Schema}; use crate::types::{IOCompletions, IOResult, Value}; @@ -99,7 +99,6 @@ impl IncrementalView { pub fn can_create_view(select: &ast::Select, schema: &Schema) -> Result<()> { // Check for aggregations let (group_by_columns, aggregate_functions, _) = Self::extract_aggregation_info(select); - // Check for JOINs let (join_tables, join_condition) = Self::extract_join_info(select); if join_tables.is_some() || join_condition.is_some() { @@ -108,29 +107,6 @@ impl IncrementalView { )); } - // Check that we have a base table - let base_table_name = Self::extract_base_table(select).ok_or_else(|| { - LimboError::ParseError("views without a base table not supported yet".to_string()) - })?; - - // Get the base table - let base_table = schema.get_btree_table(&base_table_name).ok_or_else(|| { - LimboError::ParseError(format!("Table '{base_table_name}' not found in schema")) - })?; - - // Get base table column names for validation - let base_table_column_names: Vec = base_table - .columns - .iter() - .enumerate() - .map(|(i, col)| col.name.clone().unwrap_or_else(|| format!("column_{i}"))) - .collect(); - - // For non-aggregated views, validate columns are a strict subset - if group_by_columns.is_empty() && aggregate_functions.is_empty() { - Self::validate_view_columns(select, &base_table_column_names)?; - } - Ok(()) } @@ -242,6 +218,7 @@ impl IncrementalView { view_columns, group_by_columns, aggregate_functions, + schema, ) } @@ -256,6 +233,7 @@ impl IncrementalView { columns: Vec, group_by_columns: Vec, aggregate_functions: Vec, + schema: &Schema, ) -> Result { let mut records = BTreeMap::new(); @@ -302,15 +280,11 @@ impl IncrementalView { // Only create project operator for non-aggregated views let project_operator = if 
!is_aggregated { - let columns = Self::extract_project_columns(&select_stmt, &base_table_column_names) - .unwrap_or_else(|| { - // If we can't extract columns, default to projecting all columns - base_table_column_names - .iter() - .map(|name| ProjectColumn::Column(name.to_string())) - .collect() - }); - let mut proj_op = ProjectOperator::new(columns, base_table_column_names.clone()); + let mut proj_op = ProjectOperator::from_select( + &select_stmt, + base_table_column_names.clone(), + schema, + )?; proj_op.set_tracker(tracker.clone()); Some(proj_op) } else { @@ -347,52 +321,6 @@ impl IncrementalView { vec![self.base_table.clone()] } - /// Validate that view columns are a strict subset of the base table columns - /// No duplicates, no complex expressions, only simple column references - fn validate_view_columns( - select: &ast::Select, - base_table_column_names: &[String], - ) -> Result<()> { - if let ast::OneSelect::Select { ref columns, .. } = select.body.select { - let mut seen_columns = std::collections::HashSet::new(); - - for result_col in columns { - match result_col { - ast::ResultColumn::Expr(expr, _) - if matches!(expr.as_ref(), ast::Expr::Id(_)) => - { - let ast::Expr::Id(name) = expr.as_ref() else { - unreachable!() - }; - let col_name = name.as_str(); - - // Check for duplicates - if !seen_columns.insert(col_name) { - return Err(LimboError::ParseError(format!( - "Duplicate column '{col_name}' in view. Views must have columns as a strict subset of the base table (no duplicates)" - ))); - } - - // Check that column exists in base table - if !base_table_column_names.iter().any(|n| n == col_name) { - return Err(LimboError::ParseError(format!( - "Column '{col_name}' not found in base table. 
Views must have columns as a strict subset of the base table" - ))); - } - } - ast::ResultColumn::Star => { - // SELECT * is allowed - it's the full set - } - _ => { - // Any other expression is not allowed - return Err(LimboError::ParseError("Complex expressions, functions, or computed columns are not supported in views. Views must have columns as a strict subset of the base table".to_string())); - } - } - } - } - Ok(()) - } - /// Extract the base table name from a SELECT statement (for non-join cases) fn extract_base_table(select: &ast::Select) -> Option { if let ast::OneSelect::Select { @@ -417,16 +345,14 @@ impl IncrementalView { // Get the columns used by the projection operator let mut columns = Vec::new(); for col in project_op.columns() { - match col { - ProjectColumn::Column(name) => { - columns.push(name.clone()); - } - ProjectColumn::Expression { .. } => { - // For expressions, we need all columns (for now) - columns.clear(); - columns.push("*".to_string()); - break; - } + // Check if it's a simple column reference + if let turso_sqlite3_parser::ast::Expr::Id(name) = &col.expr { + columns.push(name.as_str().to_string()); + } else { + // For expressions, we need all columns (for now) + columns.clear(); + columns.push("*".to_string()); + break; } } if columns.is_empty() || columns.contains(&"*".to_string()) { @@ -808,62 +734,6 @@ impl IncrementalView { None } - /// Extract projection columns from SELECT statement - fn extract_project_columns( - select: &ast::Select, - column_names: &[String], - ) -> Option> { - use turso_parser::ast::*; - - if let OneSelect::Select { - columns: ref select_columns, - .. 
- } = select.body.select - { - let mut columns = Vec::new(); - - for result_col in select_columns { - match result_col { - ResultColumn::Expr(expr, alias) => { - match expr.as_ref() { - Expr::Id(name) => { - // Simple column reference - columns.push(ProjectColumn::Column(name.as_str().to_string())); - } - _ => { - // Expression - store it for evaluation - let alias_str = if let Some(As::As(alias_name)) = alias { - Some(alias_name.as_str().to_string()) - } else { - None - }; - columns.push(ProjectColumn::Expression { - expr: expr.clone(), - alias: alias_str, - }); - } - } - } - ResultColumn::Star => { - // Select all columns - for name in column_names { - columns.push(ProjectColumn::Column(name.as_str().to_string())); - } - } - _ => { - // For now, skip TableStar and other cases - } - } - } - - if !columns.is_empty() { - return Some(columns); - } - } - - None - } - /// Get the current records as an iterator - for cursor-based access pub fn iter(&self) -> impl Iterator)> + '_ { self.stream.to_vec().into_iter().filter_map(move |row| { @@ -927,6 +797,12 @@ impl IncrementalView { // Apply operators in pipeline let mut current_delta = delta.clone(); current_delta = self.apply_filter_to_delta(current_delta); + + // Apply projection operator if present (for non-aggregated views) + if let Some(ref mut project_op) = self.project_operator { + current_delta = project_op.process_delta(current_delta); + } + current_delta = self.apply_aggregation_to_delta(current_delta); // Update records and stream with the processed delta @@ -1083,7 +959,7 @@ mod tests { #[test] fn test_projection_function_call() { let schema = create_test_schema(); - let sql = "CREATE MATERIALIZED VIEW v AS SELECT hex(a) as hex_a, b FROM t"; + let sql = "CREATE MATERIALIZED VIEW v AS SELECT abs(a - 300) as abs_diff, b FROM t"; let view = IncrementalView::from_sql(sql, &schema).unwrap(); @@ -1101,10 +977,8 @@ mod tests { let result = temp_project.get_current_state(); let (output, _weight) = 
result.changes.first().unwrap(); - assert_eq!( - output.values, - vec![Value::Text("FF".into()), Value::Integer(20),] - ); + // abs(255 - 300) = abs(-45) = 45 + assert_eq!(output.values, vec![Value::Integer(45), Value::Integer(20),]); } #[test] @@ -1214,12 +1088,12 @@ mod tests { assert_eq!( output.values, vec![ - Value::Integer(5), // a - Value::Integer(2), // b - Value::Integer(10), // a * 2 - Value::Integer(6), // b * 3 - Value::Integer(7), // a + b - Value::Text("F".into()), // hex(15) + Value::Integer(5), // a + Value::Integer(2), // b + Value::Integer(10), // a * 2 + Value::Integer(6), // b * 3 + Value::Integer(7), // a + b + Value::Text("3135".into()), // hex(15) - SQLite converts to string "15" then hex encodes ] ); } diff --git a/core/translate/view.rs b/core/translate/view.rs index f2dcf40a8..b339e8961 100644 --- a/core/translate/view.rs +++ b/core/translate/view.rs @@ -96,7 +96,7 @@ pub fn translate_create_materialized_view( // This validation happens before updating sqlite_master to prevent // storing invalid view definitions use crate::incremental::view::IncrementalView; - IncrementalView::can_create_view(select_stmt, schema)?; + IncrementalView::can_create_view(select_stmt)?; // Reconstruct the SQL string let sql = create_materialized_view_to_str(view_name, select_stmt); diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index afb4485ac..1cdd2fb69 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -315,6 +315,14 @@ impl ProgramState { } } + pub fn set_register(&mut self, idx: usize, value: Register) { + self.registers[idx] = value; + } + + pub fn get_register(&self, idx: usize) -> &Register { + &self.registers[idx] + } + pub fn column_count(&self) -> usize { self.registers.len() } diff --git a/testing/materialized_views.test b/testing/materialized_views.test index 36cf63e52..2b6a56be3 100755 --- a/testing/materialized_views.test +++ b/testing/materialized_views.test @@ -367,4 +367,21 @@ do_execsql_test_on_specific_db {:memory:} 
matview-mixed-operations-sequence { 200|100|1 100|25|1 200|100|1 -300|150|1} \ No newline at end of file +300|150|1} + +do_execsql_test_on_specific_db {:memory:} matview-projections { + CREATE TABLE t(a,b); + + CREATE MATERIALIZED VIEW v AS + SELECT b, a, b + a as c , (b * a) + 10 as d , min(a,b) as e + FROM t + where b > 2; + + INSERT INTO t VALUES (1, 1); + INSERT INTO t VALUES (2, 2); + INSERT INTO t VALUES (3, 4); + INSERT INTO t VALUES (4, 3); + + SELECT * from v; +} {4|3|7|22|3 +3|4|7|22|3} From ffab4a89a2f51f512e525e04b99689062e0bbe4a Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Thu, 21 Aug 2025 16:01:50 -0500 Subject: [PATCH 46/73] addressed review comments from Jussi --- core/incremental/expr_compiler.rs | 48 +++++++++---- core/incremental/operator.rs | 112 +++++++++--------------------- 2 files changed, 67 insertions(+), 93 deletions(-) diff --git a/core/incremental/expr_compiler.rs b/core/incremental/expr_compiler.rs index 310922633..9ea22028a 100644 --- a/core/incremental/expr_compiler.rs +++ b/core/incremental/expr_compiler.rs @@ -189,17 +189,25 @@ impl std::fmt::Debug for CompiledExpression { } } +#[derive(PartialEq)] +enum TrivialType { + Integer, + Float, + Text, + Null, +} + impl CompiledExpression { /// Get the "type" of a trivial expression for type checking /// Returns None if type can't be determined statically - fn get_trivial_type(expr: &TrivialExpression) -> Option<&'static str> { + fn get_trivial_type(expr: &TrivialExpression) -> Option { match expr { TrivialExpression::Column(_) => None, // Can't know column type statically TrivialExpression::Immediate(val) => match val { - Value::Integer(_) => Some("integer"), - Value::Float(_) => Some("float"), - Value::Text(_) => Some("text"), - Value::Null => Some("null"), + Value::Integer(_) => Some(TrivialType::Integer), + Value::Float(_) => Some(TrivialType::Float), + Value::Text(_) => Some(TrivialType::Text), + Value::Null => Some(TrivialType::Null), _ => None, }, TrivialExpression::Binary { 
left, right, .. } => { @@ -215,9 +223,12 @@ impl CompiledExpression { } } - /// Check if an expression is trivial (columns, immediates, and simple arithmetic) - /// Only considers expressions trivial if they don't require type coercion - fn is_trivial_expr(expr: &Expr, input_column_names: &[String]) -> Option { + // Validates if an expression is trivial (columns, immediates, and simple arithmetic) + // Only considers expressions trivial if they don't require type coercion + fn try_get_trivial_expr( + expr: &Expr, + input_column_names: &[String], + ) -> Option { match expr { // Column reference or register Expr::Id(name) => input_column_names @@ -254,8 +265,8 @@ impl CompiledExpression { match op { Operator::Add | Operator::Subtract | Operator::Multiply | Operator::Divide => { // Both operands must be trivial - let left_trivial = Self::is_trivial_expr(left, input_column_names)?; - let right_trivial = Self::is_trivial_expr(right, input_column_names)?; + let left_trivial = Self::try_get_trivial_expr(left, input_column_names)?; + let right_trivial = Self::try_get_trivial_expr(right, input_column_names)?; // Check if we can determine types statically // If both are immediates, they must have the same type @@ -267,8 +278,8 @@ impl CompiledExpression { ) { // Both types are known - they must match (or one is null) if left_type != right_type - && left_type != "null" - && right_type != "null" + && left_type != TrivialType::Null + && right_type != TrivialType::Null { return None; // Type mismatch - not trivial } @@ -288,7 +299,7 @@ impl CompiledExpression { // Parenthesized expressions with single element Expr::Parenthesized(exprs) if exprs.len() == 1 => { - Self::is_trivial_expr(&exprs[0], input_column_names) + Self::try_get_trivial_expr(&exprs[0], input_column_names) } _ => None, @@ -309,7 +320,7 @@ impl CompiledExpression { let input_count = input_column_names.len(); // First, check if this is a trivial expression - if let Some(trivial) = Self::is_trivial_expr(expr, 
input_column_names) { + if let Some(trivial) = Self::try_get_trivial_expr(expr, input_column_names) { return Ok(CompiledExpression { executor: ExpressionExecutor::Trivial(trivial), input_count, @@ -389,7 +400,14 @@ impl CompiledExpression { let mut state = ProgramState::new(program.max_registers, 0); // Load input values into registers - for (idx, value) in values.iter().take(self.input_count).enumerate() { + assert_eq!( + values.len(), + self.input_count, + "Mismatch in number of registers! Got {}, expected {}", + values.len(), + self.input_count + ); + for (idx, value) in values.iter().enumerate() { state.set_register(idx, Register::Value(value.clone())); } diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 9b0a9d15a..34f9b3ff0 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -651,111 +651,71 @@ impl ProjectOperator { None }; // Try to compile the expression (handles both columns and complex expressions) - match CompiledExpression::compile( + let compiled = CompiledExpression::compile( expr, &input_column_names, schema, &temp_syms, internal_conn.clone(), - ) { - Ok(compiled) => { - columns.push(ProjectColumn { - expr: expr.clone(), - alias: alias_str, - compiled, - }); - } - Err(_) => { - // If compilation fails, skip this column for now - // In the future we might want to handle this better - } - } + )?; + columns.push(ProjectColumn { + expr: expr.clone(), + alias: alias_str, + compiled, + }); } ResultColumn::Star => { // Select all columns - create trivial column references for name in &input_column_names { // Create an Id expression for the column let expr = Expr::Id(Name::Ident(name.clone())); - // This should always compile successfully as a trivial column reference - if let Ok(compiled) = CompiledExpression::compile( + let compiled = CompiledExpression::compile( &expr, &input_column_names, schema, &temp_syms, internal_conn.clone(), - ) { - columns.push(ProjectColumn { - expr, - alias: None, - 
compiled, - }); - } + )?; + columns.push(ProjectColumn { + expr, + alias: None, + compiled, + }); } } - _ => { - // For now, skip TableStar and other cases + x => { + return Err(crate::LimboError::ParseError(format!( + "Unsupported {x:?} clause when compiling project operator", + ))); } } } if columns.is_empty() { - // If no columns were extracted, default to projecting all input columns - input_column_names - .iter() - .filter_map(|name| { - let expr = Expr::Id(Name::Ident(name.clone())); - CompiledExpression::compile( - &expr, - &input_column_names, - schema, - &temp_syms, - internal_conn.clone(), - ) - .ok() - .map(|compiled| ProjectColumn { - expr, - alias: None, - compiled, - }) - }) - .collect() - } else { - columns + return Err(crate::LimboError::ParseError( + "No columns found when compiling project operator".to_string(), + )); } + columns } else { - // Not a simple SELECT statement, default to projecting all columns - input_column_names - .iter() - .filter_map(|name| { - let expr = Expr::Id(Name::Ident(name.clone())); - CompiledExpression::compile( - &expr, - &input_column_names, - schema, - &temp_syms, - internal_conn.clone(), - ) - .ok() - .map(|compiled| ProjectColumn { - expr, - alias: None, - compiled, - }) - }) - .collect() + return Err(crate::LimboError::ParseError( + "Expression is not a valid SELECT expression".to_string(), + )); }; // Generate output column names based on aliases or expressions let output_column_names = columns .iter() .map(|c| { - c.alias.clone().unwrap_or_else(|| { - // For simple column references, use the column name - if let Expr::Id(name) = &c.expr { - name.as_str().to_string() - } else { - "expr".to_string() + c.alias.clone().unwrap_or_else(|| match &c.expr { + Expr::Id(name) => name.as_str().to_string(), + Expr::Qualified(table, column) => { + format!("{}.{}", table.as_str(), column.as_str()) } + Expr::DoublyQualified(db, table, column) => { + format!("{}.{}.{}", db.as_str(), table.as_str(), column.as_str()) + } + _ => 
c.expr.to_string(), }) }) .collect(); @@ -786,11 +746,7 @@ impl ProjectOperator { let result = col .compiled .execute(values, internal_pager) - .unwrap_or_else(|_| { - // Fall back to manual evaluation on error - // This can happen for expressions with unsupported operations - self.evaluate_expression(&col.expr, values) - }); + .expect("Failed to execute compiled expression for the Project operator"); output.push(result); } From 8eab179a5380672bea819e195a4ba4bfec185162 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sun, 24 Aug 2025 15:11:16 +0300 Subject: [PATCH 47/73] parser/ast: Add Register AST node --- parser/src/ast.rs | 3 +++ parser/src/ast/fmt.rs | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 440685014..a55176304 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -345,6 +345,9 @@ pub enum Expr { }, /// binary expression Binary(Box, Operator, Box), + /// Register reference for DBSP expression compilation + /// This is not part of SQL syntax but used internally for incremental computation + Register(usize), /// `CASE` expression Case { /// operand diff --git a/parser/src/ast/fmt.rs b/parser/src/ast/fmt.rs index 16cf635d2..ee0b6716e 100644 --- a/parser/src/ast/fmt.rs +++ b/parser/src/ast/fmt.rs @@ -708,6 +708,11 @@ impl ToTokens for Expr { op.to_tokens_with_context(s, context)?; rhs.to_tokens_with_context(s, context) } + Self::Register(reg) => { + // This is for internal use only, not part of SQL syntax + // Use a special notation that won't conflict with SQL + s.append(TK_VARIABLE, Some(&format!("$r{reg}"))) + } Self::Case { base, when_then_pairs, From e3ffc82a1de311829e8a1271b58833ccc90dc97d Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sun, 24 Aug 2025 15:12:52 +0300 Subject: [PATCH 48/73] core/incremental: Fix expression compiler to use new parser --- core/incremental/expr_compiler.rs | 14 ++++++-------- core/incremental/operator.rs | 17 ++++++++++------- core/incremental/view.rs | 6 ++---- 
3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/core/incremental/expr_compiler.rs b/core/incremental/expr_compiler.rs index 9ea22028a..f94d72a2a 100644 --- a/core/incremental/expr_compiler.rs +++ b/core/incremental/expr_compiler.rs @@ -14,7 +14,7 @@ use crate::SymbolTable; use crate::{CaptureDataChangesMode, Connection, QueryMode, Result, Value}; use std::rc::Rc; use std::sync::Arc; -use turso_sqlite3_parser::ast::{Expr, Literal, Operator}; +use turso_parser::ast::{Expr, Literal, Operator}; // Transform an expression to replace column references with Register expressions Why do we want to // do this? @@ -65,19 +65,17 @@ fn transform_expr_for_dbsp(expr: &Expr, input_column_names: &[String]) -> Expr { } => Expr::FunctionCall { name: name.clone(), distinctness: *distinctness, - args: args.as_ref().map(|args_vec| { - args_vec - .iter() - .map(|arg| transform_expr_for_dbsp(arg, input_column_names)) - .collect() - }), + args: args + .iter() + .map(|arg| Box::new(transform_expr_for_dbsp(arg, input_column_names))) + .collect(), order_by: order_by.clone(), filter_over: filter_over.clone(), }, Expr::Parenthesized(exprs) => Expr::Parenthesized( exprs .iter() - .map(|e| transform_expr_for_dbsp(e, input_column_names)) + .map(|e| Box::new(transform_expr_for_dbsp(e, input_column_names))) .collect(), ), // For other expression types, keep as is diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 34f9b3ff0..0391e3c0a 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -10,7 +10,6 @@ use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display}; use std::sync::Arc; use std::sync::Mutex; -use turso_sqlite3_parser::ast::*; /// Tracks computation counts to verify incremental behavior (for tests now), and in the future /// should be used to provide statistics. 
@@ -346,7 +345,7 @@ impl FilterPredicate { #[derive(Debug, Clone)] pub struct ProjectColumn { /// The original SQL expression (for debugging/fallback) - pub expr: turso_sqlite3_parser::ast::Expr, + pub expr: turso_parser::ast::Expr, /// Optional alias for the column pub alias: Option, /// Compiled expression (handles both trivial columns and complex expressions) @@ -620,11 +619,11 @@ impl std::fmt::Debug for ProjectOperator { impl ProjectOperator { /// Create a new ProjectOperator from a SELECT statement, extracting projection columns pub fn from_select( - select: &turso_sqlite3_parser::ast::Select, + select: &turso_parser::ast::Select, input_column_names: Vec, schema: &crate::schema::Schema, ) -> crate::Result { - use turso_sqlite3_parser::ast::*; + use turso_parser::ast::*; // Set up internal connection for expression evaluation let io = Arc::new(crate::MemoryIO::new()); @@ -640,9 +639,13 @@ impl ProjectOperator { let temp_syms = SymbolTable::new(); // Extract columns from SELECT statement - let columns = if let OneSelect::Select(ref select_stmt) = &*select.body.select { + let columns = if let OneSelect::Select { + columns: ref select_columns, + .. 
+ } = &select.body.select + { let mut columns = Vec::new(); - for result_col in &select_stmt.columns { + for result_col in select_columns { match result_col { ResultColumn::Expr(expr, alias) => { let alias_str = if let Some(As::As(alias_name)) = alias { @@ -659,7 +662,7 @@ impl ProjectOperator { internal_conn.clone(), )?; columns.push(ProjectColumn { - expr: expr.clone(), + expr: (**expr).clone(), alias: alias_str, compiled, }); diff --git a/core/incremental/view.rs b/core/incremental/view.rs index 360eca7ce..4f4d4c6e6 100644 --- a/core/incremental/view.rs +++ b/core/incremental/view.rs @@ -96,9 +96,7 @@ pub struct IncrementalView { impl IncrementalView { /// Validate that a CREATE MATERIALIZED VIEW statement can be handled by IncrementalView /// This should be called early, before updating sqlite_master - pub fn can_create_view(select: &ast::Select, schema: &Schema) -> Result<()> { - // Check for aggregations - let (group_by_columns, aggregate_functions, _) = Self::extract_aggregation_info(select); + pub fn can_create_view(select: &ast::Select) -> Result<()> { // Check for JOINs let (join_tables, join_condition) = Self::extract_join_info(select); if join_tables.is_some() || join_condition.is_some() { @@ -346,7 +344,7 @@ impl IncrementalView { let mut columns = Vec::new(); for col in project_op.columns() { // Check if it's a simple column reference - if let turso_sqlite3_parser::ast::Expr::Id(name) = &col.expr { + if let turso_parser::ast::Expr::Id(name) = &col.expr { columns.push(name.as_str().to_string()); } else { // For expressions, we need all columns (for now) From 9f6468ec82e70d279250f6d9b9c24ed248eae5de Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Mon, 25 Aug 2025 17:51:07 +0300 Subject: [PATCH 49/73] sqlite3: Implement sqlite3_malloc() and sqlite3_free() --- sqlite3/src/lib.rs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/sqlite3/src/lib.rs b/sqlite3/src/lib.rs index 46d6d64b3..1e19c9eee 100644 --- 
a/sqlite3/src/lib.rs +++ b/sqlite3/src/lib.rs @@ -461,13 +461,24 @@ pub unsafe extern "C" fn sqlite3_limit( } #[no_mangle] -pub unsafe extern "C" fn sqlite3_malloc64(_n: ffi::c_int) -> *mut ffi::c_void { - stub!(); +pub unsafe extern "C" fn sqlite3_malloc(n: ffi::c_int) -> *mut ffi::c_void { + sqlite3_malloc64(n) } #[no_mangle] -pub unsafe extern "C" fn sqlite3_free(_ptr: *mut ffi::c_void) { - stub!(); +pub unsafe extern "C" fn sqlite3_malloc64(n: ffi::c_int) -> *mut ffi::c_void { + if n <= 0 { + return std::ptr::null_mut(); + } + libc::malloc(n as usize) +} + +#[no_mangle] +pub unsafe extern "C" fn sqlite3_free(ptr: *mut ffi::c_void) { + if ptr.is_null() { + return; + } + libc::free(ptr); } /// Returns the error code for the most recent failed API call to connection. From 1b514e6d0fce8b362ab61846df84ec2ab500af7e Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 24 Aug 2025 11:09:24 -0400 Subject: [PATCH 50/73] Only checkpoint final remaining DB connection, and use Truncate mode --- core/lib.rs | 37 +++++++++++++++++++++++++++++++++---- core/storage/pager.rs | 26 +++++++++++++++++++++----- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index 7f35f8861..aeffb58c5 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -71,7 +71,7 @@ use std::{ num::NonZero, ops::Deref, rc::Rc, - sync::{Arc, LazyLock, Mutex, Weak}, + sync::{atomic::AtomicUsize, Arc, LazyLock, Mutex, Weak}, }; #[cfg(feature = "fs")] use storage::database::DatabaseFile; @@ -137,6 +137,7 @@ pub struct Database { open_flags: OpenFlags, builtin_syms: RefCell, experimental_views: bool, + n_connections: AtomicUsize, } unsafe impl Send for Database {} @@ -185,6 +186,12 @@ impl fmt::Debug for Database { }; debug_struct.field("page_cache", &cache_info); + debug_struct.field( + "n_connections", + &self + .n_connections + .load(std::sync::atomic::Ordering::Relaxed), + ); debug_struct.finish() } } @@ -372,6 +379,7 @@ impl Database { init_lock: Arc::new(Mutex::new(())), 
experimental_views: enable_views, buffer_pool: BufferPool::begin_init(&io, arena_size), + n_connections: AtomicUsize::new(0), }); db.register_global_builtin_extensions() .expect("unable to register global extensions"); @@ -425,6 +433,8 @@ impl Database { .unwrap_or_default() .get(); + self.n_connections + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); let conn = Arc::new(Connection { _db: self.clone(), pager: RefCell::new(Rc::new(pager)), @@ -888,6 +898,17 @@ pub struct Connection { encryption_key: RefCell>, } +impl Drop for Connection { + fn drop(&mut self) { + if !self.closed.get() { + // if connection wasn't properly closed, decrement the connection counter + self._db + .n_connections + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } + } +} + impl Connection { #[instrument(skip_all, level = Level::INFO)] pub fn prepare(self: &Arc, sql: impl AsRef) -> Result { @@ -1506,9 +1527,17 @@ impl Connection { } } - self.pager - .borrow() - .checkpoint_shutdown(self.wal_auto_checkpoint_disabled.get()) + if self + ._db + .n_connections + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed) + .eq(&1) + { + self.pager + .borrow() + .checkpoint_shutdown(self.wal_auto_checkpoint_disabled.get()); + }; + Ok(()) } pub fn wal_auto_checkpoint_disable(&self) { diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 1149acf61..cc528cd74 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1555,8 +1555,18 @@ impl Pager { .expect("Failed to clear page cache"); } + /// Checkpoint in Truncate mode and delete the WAL file. This method is _only_ to be called + /// for shutting down the last remaining connection to a database. 
+ /// + /// sqlite3.h + /// Usually, when a database in [WAL mode] is closed or detached from a + /// database handle, SQLite checks if if there are other connections to the + /// same database, and if there are no other database connection (if the + /// connection being closed is the last open connection to the database), + /// then SQLite performs a [checkpoint] before closing the connection and + /// deletes the WAL file. pub fn checkpoint_shutdown(&self, wal_auto_checkpoint_disabled: bool) -> Result<()> { - let mut _attempts = 0; + let mut attempts = 0; { let Some(wal) = self.wal.as_ref() else { return Err(LimboError::InternalError( @@ -1565,16 +1575,22 @@ impl Pager { }; let mut wal = wal.borrow_mut(); // fsync the wal syncronously before beginning checkpoint - // TODO: for now forget about timeouts as they fail regularly in SIM - // need to think of a better way to do this let c = wal.sync()?; self.io.wait_for_completion(c)?; } if !wal_auto_checkpoint_disabled { - self.wal_checkpoint(CheckpointMode::Passive { + while let Err(LimboError::Busy) = self.wal_checkpoint(CheckpointMode::Truncate { upper_bound_inclusive: None, - })?; + }) { + if attempts == 3 { + // don't return error on `close` if we are unable to checkpoint, we can + // silently fail + return Ok(()); + } + attempts += 1; + } } + // TODO: delete the WAL file here after truncate checkpoint Ok(()) } From 748e339f68a4b1bc259a7b84ba70c0488bc9451a Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 24 Aug 2025 12:07:42 -0400 Subject: [PATCH 51/73] Make clippy happy --- core/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/lib.rs b/core/lib.rs index aeffb58c5..f9f76fe91 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -1535,7 +1535,7 @@ impl Connection { { self.pager .borrow() - .checkpoint_shutdown(self.wal_auto_checkpoint_disabled.get()); + .checkpoint_shutdown(self.wal_auto_checkpoint_disabled.get())?; }; Ok(()) } From 2d661e3304a3ded1678b7a289720b2b47d994aac Mon Sep 17 
00:00:00 2001 From: PThorpe92 Date: Mon, 25 Aug 2025 16:56:43 -0400 Subject: [PATCH 52/73] Apply review suggestions, add logging --- core/lib.rs | 4 ++-- core/storage/pager.rs | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index f9f76fe91..4ea22bfd1 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -433,8 +433,6 @@ impl Database { .unwrap_or_default() .get(); - self.n_connections - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); let conn = Arc::new(Connection { _db: self.clone(), pager: RefCell::new(Rc::new(pager)), @@ -466,6 +464,8 @@ impl Database { is_nested_stmt: Cell::new(false), encryption_key: RefCell::new(None), }); + self.n_connections + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); let builtin_syms = self.builtin_syms.borrow(); // add built-in extensions symbols to the connection to prevent having to load each time conn.syms.borrow_mut().extend(&builtin_syms); diff --git a/core/storage/pager.rs b/core/storage/pager.rs index cc528cd74..f4ba0fc5f 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -1583,14 +1583,17 @@ impl Pager { upper_bound_inclusive: None, }) { if attempts == 3 { - // don't return error on `close` if we are unable to checkpoint, we can - // silently fail + // don't return error on `close` if we are unable to checkpoint, we can silently fail + tracing::warn!( + "Failed to checkpoint WAL on shutdown after 3 attempts, giving up" + ); return Ok(()); } attempts += 1; } } - // TODO: delete the WAL file here after truncate checkpoint + // TODO: delete the WAL file here after truncate checkpoint, but *only* if we are sure that + // no other connections have opened since. 
Ok(()) } From 177c717f2565e99afd406122b62700b318f52023 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 25 Aug 2025 18:47:21 -0400 Subject: [PATCH 53/73] Remove windows IO in place of Generic IO --- core/io/mod.rs | 8 +--- core/io/windows.rs | 115 --------------------------------------------- 2 files changed, 1 insertion(+), 122 deletions(-) delete mode 100644 core/io/windows.rs diff --git a/core/io/mod.rs b/core/io/mod.rs index 1fded463e..992eabac0 100644 --- a/core/io/mod.rs +++ b/core/io/mod.rs @@ -506,13 +506,7 @@ cfg_block! { pub use PlatformIO as SyscallIO; } - #[cfg(target_os = "windows")] { - mod windows; - pub use windows::WindowsIO as PlatformIO; - pub use PlatformIO as SyscallIO; - } - - #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "android", target_os = "ios")))] { + #[cfg(not(any(target_family = "unix", target_os = "android", target_os = "ios")))] { mod generic; pub use generic::GenericIO as PlatformIO; pub use PlatformIO as SyscallIO; diff --git a/core/io/windows.rs b/core/io/windows.rs deleted file mode 100644 index acb12b344..000000000 --- a/core/io/windows.rs +++ /dev/null @@ -1,115 +0,0 @@ -use crate::{Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO}; -use parking_lot::RwLock; -use std::io::{Read, Seek, Write}; -use std::sync::Arc; -use tracing::{debug, instrument, trace, Level}; -pub struct WindowsIO {} - -impl WindowsIO { - pub fn new() -> Result { - debug!("Using IO backend 'syscall'"); - Ok(Self {}) - } -} - -impl IO for WindowsIO { - #[instrument(err, skip_all, level = Level::TRACE)] - fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result> { - trace!("open_file(path = {})", path); - let mut file = std::fs::File::options(); - file.read(true); - - if !flags.contains(OpenFlags::ReadOnly) { - file.write(true); - file.create(flags.contains(OpenFlags::Create)); - } - - let file = file.open(path)?; - Ok(Arc::new(WindowsFile { - file: RwLock::new(file), - })) - } 
- - #[instrument(err, skip_all, level = Level::TRACE)] - fn remove_file(&self, path: &str) -> Result<()> { - trace!("remove_file(path = {})", path); - Ok(std::fs::remove_file(path)?) - } - - #[instrument(err, skip_all, level = Level::TRACE)] - fn run_once(&self) -> Result<()> { - Ok(()) - } -} - -impl Clock for WindowsIO { - fn now(&self) -> Instant { - let now = chrono::Local::now(); - Instant { - secs: now.timestamp(), - micros: now.timestamp_subsec_micros(), - } - } -} - -pub struct WindowsFile { - file: RwLock, -} - -impl File for WindowsFile { - #[instrument(err, skip_all, level = Level::TRACE)] - fn lock_file(&self, exclusive: bool) -> Result<()> { - unimplemented!() - } - - #[instrument(err, skip_all, level = Level::TRACE)] - fn unlock_file(&self) -> Result<()> { - unimplemented!() - } - - #[instrument(skip(self, c), level = Level::TRACE)] - fn pread(&self, pos: usize, c: Completion) -> Result { - let mut file = self.file.write(); - file.seek(std::io::SeekFrom::Start(pos as u64))?; - let nr = { - let r = c.as_read(); - let buf = r.buf(); - let buf = buf.as_mut_slice(); - file.read_exact(buf)?; - buf.len() as i32 - }; - c.complete(nr); - Ok(c) - } - - #[instrument(skip(self, c, buffer), level = Level::TRACE)] - fn pwrite(&self, pos: usize, buffer: Arc, c: Completion) -> Result { - let mut file = self.file.write(); - file.seek(std::io::SeekFrom::Start(pos as u64))?; - let buf = buffer.as_slice(); - file.write_all(buf)?; - c.complete(buffer.len() as i32); - Ok(c) - } - - #[instrument(err, skip_all, level = Level::TRACE)] - fn sync(&self, c: Completion) -> Result { - let file = self.file.write(); - file.sync_all()?; - c.complete(0); - Ok(c) - } - - #[instrument(err, skip_all, level = Level::TRACE)] - fn truncate(&self, len: usize, c: Completion) -> Result { - let file = self.file.write(); - file.set_len(len as u64)?; - c.complete(0); - Ok(c) - } - - fn size(&self) -> Result { - let file = self.file.read(); - Ok(file.metadata().unwrap().len()) - } -} From 
5108c72a28802d202145364c3fe967ac714ad821 Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Mon, 25 Aug 2025 15:53:45 -0300 Subject: [PATCH 54/73] remove box from `Vec>` --- parser/src/ast.rs | 134 ++++++++++- parser/src/parser.rs | 551 +++++++++++++++++-------------------------- 2 files changed, 347 insertions(+), 338 deletions(-) diff --git a/parser/src/ast.rs b/parser/src/ast.rs index a55176304..5626ffbaa 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -285,6 +285,23 @@ pub enum Stmt { }, } +impl Stmt { + pub fn attach(expr: Expr, db_name: Expr, key: Option) -> Stmt { + Stmt::Attach { + expr: Box::new(expr), + db_name: Box::new(db_name), + key: key.map(Box::new), + } + } + + pub fn vacuum(name: Option, into: Option) -> Stmt { + Stmt::Vacuum { + name, + into: into.map(Box::new), + } + } +} + #[repr(transparent)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -353,7 +370,7 @@ pub enum Expr { /// operand base: Option>, /// `WHEN` condition `THEN` result - when_then_pairs: Vec<(Box, Box)>, + when_then_pairs: Vec<(Expr, Expr)>, /// `ELSE` result else_expr: Option>, }, @@ -377,7 +394,7 @@ pub enum Expr { /// `DISTINCT` distinctness: Option, /// arguments - args: Vec>, + args: Vec, /// `ORDER BY` order_by: Vec, /// `FILTER` @@ -417,7 +434,7 @@ pub enum Expr { /// `NOT` not: bool, /// values - rhs: Vec>, + rhs: Vec, }, /// `IN` subselect InSelect { @@ -437,7 +454,7 @@ pub enum Expr { /// table name rhs: QualifiedName, /// table function arguments - args: Vec>, + args: Vec, }, /// `IS NULL` IsNull(Box), @@ -461,7 +478,7 @@ pub enum Expr { /// `NOT NULL` or `NOTNULL` NotNull(Box), /// Parenthesized subexpression - Parenthesized(Vec>), + Parenthesized(Vec), /// Qualified name Qualified(Name, Name), /// `RAISE` function call @@ -474,6 +491,105 @@ pub enum Expr { Variable(String), } +impl Expr { + pub fn into_boxed(self) -> Box { + Box::new(self) + } + + pub fn unary(operator: UnaryOperator, 
expr: Expr) -> Expr { + Expr::Unary(operator, Box::new(expr)) + } + + pub fn binary(lhs: Expr, operator: Operator, rhs: Expr) -> Expr { + Expr::Binary(Box::new(lhs), operator, Box::new(rhs)) + } + + pub fn not_null(expr: Expr) -> Expr { + Expr::NotNull(Box::new(expr)) + } + + pub fn between(lhs: Expr, not: bool, start: Expr, end: Expr) -> Expr { + Expr::Between { + lhs: Box::new(lhs), + not, + start: Box::new(start), + end: Box::new(end), + } + } + + pub fn in_select(lhs: Expr, not: bool, select: Select) -> Expr { + Expr::InSelect { + lhs: Box::new(lhs), + not, + rhs: select, + } + } + + pub fn in_list(lhs: Expr, not: bool, rhs: Vec) -> Expr { + Expr::InList { + lhs: Box::new(lhs), + not, + rhs, + } + } + + pub fn in_table(lhs: Expr, not: bool, rhs: QualifiedName, args: Vec) -> Expr { + Expr::InTable { + lhs: Box::new(lhs), + not, + rhs, + args, + } + } + + pub fn like( + lhs: Expr, + not: bool, + operator: LikeOperator, + rhs: Expr, + escape: Option, + ) -> Expr { + Expr::Like { + lhs: Box::new(lhs), + not, + op: operator, + rhs: Box::new(rhs), + escape: escape.map(Box::new), + } + } + + pub fn is_null(expr: Expr) -> Expr { + Expr::IsNull(Box::new(expr)) + } + + pub fn collate(expr: Expr, name: Name) -> Expr { + Expr::Collate(Box::new(expr), name) + } + + pub fn cast(expr: Expr, type_name: Option) -> Expr { + Expr::Cast { + expr: Box::new(expr), + type_name, + } + } + + pub fn case( + base: Option, + when_then_pairs: Vec<(Expr, Expr)>, + else_expr: Option, + ) -> Expr { + Expr::Case { + base: base.map(Box::new), + when_then_pairs, + else_expr: else_expr.map(Box::new), + } + } + + pub fn raise(resolve_type: ResolveType, expr: Option) -> Expr { + Expr::Raise(resolve_type, expr.map(Box::new)) + } +} + /// SQL literal #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -680,7 +796,7 @@ pub enum OneSelect { window_clause: Vec, }, /// `VALUES` - Values(Vec>>), + Values(Vec>), } /// `SELECT` ... 
`FROM` clause @@ -748,7 +864,7 @@ pub enum SelectTable { /// table Table(QualifiedName, Option, Option), /// table function call - TableCall(QualifiedName, Vec>, Option), + TableCall(QualifiedName, Vec, Option), /// `SELECT` subquery Select(Select, Option), /// subquery @@ -802,7 +918,7 @@ pub enum JoinConstraint { #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct GroupBy { /// expressions - pub exprs: Vec>, + pub exprs: Vec, /// `HAVING` pub having: Option>, // HAVING clause on a non-aggregate query } @@ -1505,7 +1621,7 @@ pub struct Window { /// base window name pub base: Option, /// `PARTITION BY` - pub partition_by: Vec>, + pub partition_by: Vec, /// `ORDER BY` pub order_by: Vec, /// frame spec diff --git a/parser/src/parser.rs b/parser/src/parser.rs index 124058fc9..c21a8fd7b 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -953,11 +953,11 @@ impl<'a> Parser<'a> { } } - fn parse_signed(&mut self) -> Result> { + fn parse_signed(&mut self) -> Result { peek_expect!(self, TK_FLOAT, TK_INTEGER, TK_PLUS, TK_MINUS); let expr = self.parse_expr_operand()?; - match expr.as_ref() { + match &expr { Expr::Unary(_, inner) => match inner.as_ref() { Expr::Literal(Literal::Numeric(_)) => Ok(expr), _ => Err(Error::Custom( @@ -998,11 +998,14 @@ impl<'a> Parser<'a> { let first_size = self.parse_signed()?; let tok = eat_expect!(self, TK_RP, TK_COMMA); match tok.token_type.unwrap() { - TK_RP => Some(TypeSize::MaxSize(first_size)), + TK_RP => Some(TypeSize::MaxSize(Box::new(first_size))), TK_COMMA => { let second_size = self.parse_signed()?; eat_expect!(self, TK_RP); - Some(TypeSize::TypeSize(first_size, second_size)) + Some(TypeSize::TypeSize( + Box::new(first_size), + Box::new(second_size), + )) } _ => unreachable!(), } @@ -1098,7 +1101,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_WHERE); let expr = self.parse_expr(0)?; eat_expect!(self, TK_RP); - Ok(Some(expr)) + Ok(Some(Box::new(expr))) } fn parse_frame_opt(&mut self) -> 
Result> { @@ -1141,7 +1144,7 @@ impl<'a> Parser<'a> { FrameBound::CurrentRow } _ => { - let expr = self.parse_expr(0)?; + let expr = Box::new(self.parse_expr(0)?); let tok = eat_expect!(self, TK_PRECEDING, TK_FOLLOWING); match tok.token_type.unwrap() { TK_PRECEDING => FrameBound::Preceding(expr), @@ -1166,7 +1169,7 @@ impl<'a> Parser<'a> { FrameBound::CurrentRow } _ => { - let expr = self.parse_expr(0)?; + let expr = Box::new(self.parse_expr(0)?); let tok = eat_expect!(self, TK_PRECEDING, TK_FOLLOWING); match tok.token_type.unwrap() { TK_PRECEDING => FrameBound::Preceding(expr), @@ -1285,7 +1288,7 @@ impl<'a> Parser<'a> { } } - fn parse_expr_operand(&mut self) -> Result> { + fn parse_expr_operand(&mut self) -> Result { let tok = peek_expect!( self, TK_LP, @@ -1316,40 +1319,34 @@ impl<'a> Parser<'a> { TK_WITH | TK_SELECT | TK_VALUES => { let select = self.parse_select()?; eat_expect!(self, TK_RP); - Ok(Box::new(Expr::Subquery(select))) + Ok(Expr::Subquery(select)) } _ => { let exprs = self.parse_nexpr_list()?; eat_expect!(self, TK_RP); - Ok(Box::new(Expr::Parenthesized(exprs))) + Ok(Expr::Parenthesized(exprs)) } } } TK_NULL => { eat_assert!(self, TK_NULL); - Ok(Box::new(Expr::Literal(Literal::Null))) + Ok(Expr::Literal(Literal::Null)) } TK_BLOB => { let tok = eat_assert!(self, TK_BLOB); - Ok(Box::new(Expr::Literal(Literal::Blob(from_bytes( - tok.value, - ))))) + Ok(Expr::Literal(Literal::Blob(from_bytes(tok.value)))) } TK_FLOAT => { let tok = eat_assert!(self, TK_FLOAT); - Ok(Box::new(Expr::Literal(Literal::Numeric(from_bytes( - tok.value, - ))))) + Ok(Expr::Literal(Literal::Numeric(from_bytes(tok.value)))) } TK_INTEGER => { let tok = eat_assert!(self, TK_INTEGER); - Ok(Box::new(Expr::Literal(Literal::Numeric(from_bytes( - tok.value, - ))))) + Ok(Expr::Literal(Literal::Numeric(from_bytes(tok.value)))) } TK_VARIABLE => { let tok = eat_assert!(self, TK_VARIABLE); - Ok(Box::new(Expr::Variable(from_bytes(tok.value)))) + Ok(Expr::Variable(from_bytes(tok.value))) } 
TK_CAST => { eat_assert!(self, TK_CAST); @@ -1358,19 +1355,16 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_AS); let typ = self.parse_type()?; eat_expect!(self, TK_RP); - Ok(Box::new(Expr::Cast { - expr, - type_name: typ, - })) + Ok(Expr::cast(expr, typ)) } TK_CTIME_KW => { let tok = eat_assert!(self, TK_CTIME_KW); if b"CURRENT_DATE".eq_ignore_ascii_case(tok.value) { - Ok(Box::new(Expr::Literal(Literal::CurrentDate))) + Ok(Expr::Literal(Literal::CurrentDate)) } else if b"CURRENT_TIME".eq_ignore_ascii_case(tok.value) { - Ok(Box::new(Expr::Literal(Literal::CurrentTime))) + Ok(Expr::Literal(Literal::CurrentTime)) } else if b"CURRENT_TIMESTAMP".eq_ignore_ascii_case(tok.value) { - Ok(Box::new(Expr::Literal(Literal::CurrentTimestamp))) + Ok(Expr::Literal(Literal::CurrentTimestamp)) } else { unreachable!() } @@ -1378,29 +1372,29 @@ impl<'a> Parser<'a> { TK_NOT => { eat_assert!(self, TK_NOT); let expr = self.parse_expr(2)?; // NOT precedence is 2 - Ok(Box::new(Expr::Unary(UnaryOperator::Not, expr))) + Ok(Expr::unary(UnaryOperator::Not, expr)) } TK_BITNOT => { eat_assert!(self, TK_BITNOT); let expr = self.parse_expr(11)?; // BITNOT precedence is 11 - Ok(Box::new(Expr::Unary(UnaryOperator::BitwiseNot, expr))) + Ok(Expr::unary(UnaryOperator::BitwiseNot, expr)) } TK_PLUS => { eat_assert!(self, TK_PLUS); let expr = self.parse_expr(11)?; // PLUS precedence is 11 - Ok(Box::new(Expr::Unary(UnaryOperator::Positive, expr))) + Ok(Expr::unary(UnaryOperator::Positive, expr)) } TK_MINUS => { eat_assert!(self, TK_MINUS); let expr = self.parse_expr(11)?; // MINUS precedence is 11 - Ok(Box::new(Expr::Unary(UnaryOperator::Negative, expr))) + Ok(Expr::unary(UnaryOperator::Negative, expr)) } TK_EXISTS => { eat_assert!(self, TK_EXISTS); eat_expect!(self, TK_LP); let select = self.parse_select()?; eat_expect!(self, TK_RP); - Ok(Box::new(Expr::Exists(select))) + Ok(Expr::Exists(select)) } TK_CASE => { eat_assert!(self, TK_CASE); @@ -1439,11 +1433,7 @@ impl<'a> Parser<'a> { }; eat_expect!(self, 
TK_END); - Ok(Box::new(Expr::Case { - base, - when_then_pairs, - else_expr, - })) + Ok(Expr::case(base, when_then_pairs, else_expr)) } TK_RAISE => { eat_assert!(self, TK_RAISE); @@ -1465,7 +1455,7 @@ impl<'a> Parser<'a> { }; eat_expect!(self, TK_RP); - Ok(Box::new(Expr::Raise(resolve, expr))) + Ok(Expr::raise(resolve, expr)) } _ => { let can_be_lit_str = tok.token_type == Some(TK_STRING); @@ -1490,10 +1480,10 @@ impl<'a> Parser<'a> { TK_STAR => { eat_assert!(self, TK_STAR); eat_expect!(self, TK_RP); - return Ok(Box::new(Expr::FunctionCallStar { + return Ok(Expr::FunctionCallStar { name, filter_over: self.parse_filter_over()?, - })); + }); } _ => { let distinct = self.parse_distinct()?; @@ -1501,13 +1491,13 @@ impl<'a> Parser<'a> { let order_by = self.parse_order_by()?; eat_expect!(self, TK_RP); let filter_over = self.parse_filter_over()?; - return Ok(Box::new(Expr::FunctionCall { + return Ok(Expr::FunctionCall { name, distinctness: distinct, args: exprs, order_by, filter_over, - })); + }); } } } else { @@ -1531,28 +1521,24 @@ impl<'a> Parser<'a> { if let Some(second_name) = second_name { if let Some(third_name) = third_name { - Ok(Box::new(Expr::DoublyQualified( - name, - second_name, - third_name, - ))) + Ok(Expr::DoublyQualified(name, second_name, third_name)) } else { - Ok(Box::new(Expr::Qualified(name, second_name))) + Ok(Expr::Qualified(name, second_name)) } } else if can_be_lit_str { - Ok(Box::new(Expr::Literal(match name { + Ok(Expr::Literal(match name { Name::Quoted(s) => Literal::String(s), Name::Ident(s) => Literal::String(s), - }))) + })) } else { - Ok(Box::new(Expr::Id(name))) + Ok(Expr::Id(name)) } } } } #[allow(clippy::vec_box)] - fn parse_expr_list(&mut self) -> Result>> { + fn parse_expr_list(&mut self) -> Result> { let mut exprs = vec![]; while let Some(tok) = self.peek()? 
{ match tok.token_type.unwrap().fallback_id_if_ok() { @@ -1574,7 +1560,7 @@ impl<'a> Parser<'a> { Ok(exprs) } - fn parse_expr(&mut self, precedence: u8) -> Result> { + fn parse_expr(&mut self, precedence: u8) -> Result { let mut result = self.parse_expr_operand()?; loop { @@ -1594,44 +1580,28 @@ impl<'a> Parser<'a> { not = true; } - result = match tok.token_type.unwrap() { + let expr = match tok.token_type.unwrap() { TK_NULL => { // special case `NOT NULL` debug_assert!(not); // FIXME: not always true because of current_token_precedence eat_assert!(self, TK_NULL); - Box::new(Expr::NotNull(result)) + Expr::not_null(result) } TK_OR => { eat_assert!(self, TK_OR); - Box::new(Expr::Binary( - result, - Operator::Or, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Or, self.parse_expr(pre + 1)?) } TK_AND => { eat_assert!(self, TK_AND); - Box::new(Expr::Binary( - result, - Operator::And, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::And, self.parse_expr(pre + 1)?) } TK_EQ => { eat_assert!(self, TK_EQ); - Box::new(Expr::Binary( - result, - Operator::Equals, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Equals, self.parse_expr(pre + 1)?) } TK_NE => { eat_assert!(self, TK_NE); - Box::new(Expr::Binary( - result, - Operator::NotEquals, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::NotEquals, self.parse_expr(pre + 1)?) } TK_IS => { eat_assert!(self, TK_IS); @@ -1663,19 +1633,14 @@ impl<'a> Parser<'a> { } }; - Box::new(Expr::Binary(result, op, self.parse_expr(pre + 1)?)) + Expr::binary(result, op, self.parse_expr(pre + 1)?) 
} TK_BETWEEN => { eat_assert!(self, TK_BETWEEN); let start = self.parse_expr(pre)?; eat_expect!(self, TK_AND); let end = self.parse_expr(pre)?; - Box::new(Expr::Between { - lhs: result, - not, - start, - end, - }) + Expr::between(result, not, start, end) } TK_IN => { eat_assert!(self, TK_IN); @@ -1688,20 +1653,12 @@ impl<'a> Parser<'a> { TK_SELECT | TK_WITH | TK_VALUES => { let select = self.parse_select()?; eat_expect!(self, TK_RP); - Box::new(Expr::InSelect { - lhs: result, - not, - rhs: select, - }) + Expr::in_select(result, not, select) } _ => { let exprs = self.parse_expr_list()?; eat_expect!(self, TK_RP); - Box::new(Expr::InList { - lhs: result, - not, - rhs: exprs, - }) + Expr::in_list(result, not, exprs) } } } @@ -1715,13 +1672,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_RP); } } - - Box::new(Expr::InTable { - lhs: result, - not, - rhs: name, - args: exprs, - }) + Expr::in_table(result, not, name, exprs) } } } @@ -1755,134 +1706,72 @@ impl<'a> Parser<'a> { None }; - Box::new(Expr::Like { - lhs: result, - not, - op, - rhs: expr, - escape, - }) + Expr::like(result, not, op, expr, escape) } TK_ISNULL => { eat_assert!(self, TK_ISNULL); - Box::new(Expr::IsNull(result)) + Expr::is_null(result) } TK_NOTNULL => { eat_assert!(self, TK_NOTNULL); - Box::new(Expr::NotNull(result)) + Expr::not_null(result) } TK_LT => { eat_assert!(self, TK_LT); - Box::new(Expr::Binary( - result, - Operator::Less, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Less, self.parse_expr(pre + 1)?) } TK_GT => { eat_assert!(self, TK_GT); - Box::new(Expr::Binary( - result, - Operator::Greater, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Greater, self.parse_expr(pre + 1)?) } TK_LE => { eat_assert!(self, TK_LE); - Box::new(Expr::Binary( - result, - Operator::LessEquals, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::LessEquals, self.parse_expr(pre + 1)?) 
} TK_GE => { eat_assert!(self, TK_GE); - Box::new(Expr::Binary( - result, - Operator::GreaterEquals, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::GreaterEquals, self.parse_expr(pre + 1)?) } TK_ESCAPE => unreachable!(), TK_BITAND => { eat_assert!(self, TK_BITAND); - Box::new(Expr::Binary( - result, - Operator::BitwiseAnd, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::BitwiseAnd, self.parse_expr(pre + 1)?) } TK_BITOR => { eat_assert!(self, TK_BITOR); - Box::new(Expr::Binary( - result, - Operator::BitwiseOr, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::BitwiseOr, self.parse_expr(pre + 1)?) } TK_LSHIFT => { eat_assert!(self, TK_LSHIFT); - Box::new(Expr::Binary( - result, - Operator::LeftShift, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::LeftShift, self.parse_expr(pre + 1)?) } TK_RSHIFT => { eat_assert!(self, TK_RSHIFT); - Box::new(Expr::Binary( - result, - Operator::RightShift, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::RightShift, self.parse_expr(pre + 1)?) } TK_PLUS => { eat_assert!(self, TK_PLUS); - Box::new(Expr::Binary( - result, - Operator::Add, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Add, self.parse_expr(pre + 1)?) } TK_MINUS => { eat_assert!(self, TK_MINUS); - Box::new(Expr::Binary( - result, - Operator::Subtract, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Subtract, self.parse_expr(pre + 1)?) } TK_STAR => { eat_assert!(self, TK_STAR); - Box::new(Expr::Binary( - result, - Operator::Multiply, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Multiply, self.parse_expr(pre + 1)?) } TK_SLASH => { eat_assert!(self, TK_SLASH); - Box::new(Expr::Binary( - result, - Operator::Divide, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Divide, self.parse_expr(pre + 1)?) 
} TK_REM => { eat_assert!(self, TK_REM); - Box::new(Expr::Binary( - result, - Operator::Modulus, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Modulus, self.parse_expr(pre + 1)?) } TK_CONCAT => { eat_assert!(self, TK_CONCAT); - Box::new(Expr::Binary( - result, - Operator::Concat, - self.parse_expr(pre + 1)?, - )) + Expr::binary(result, Operator::Concat, self.parse_expr(pre + 1)?) } TK_PTR => { let tok = eat_assert!(self, TK_PTR); @@ -1892,11 +1781,13 @@ impl<'a> Parser<'a> { Operator::ArrowRightShift }; - Box::new(Expr::Binary(result, op, self.parse_expr(pre + 1)?)) + Expr::binary(result, op, self.parse_expr(pre + 1)?) } - TK_COLLATE => Box::new(Expr::Collate(result, self.parse_collate()?.unwrap())), + TK_COLLATE => Expr::collate(result, self.parse_collate()?.unwrap()), _ => unreachable!(), - } + }; + + result = expr; } Ok(result) @@ -2114,7 +2005,10 @@ impl<'a> Parser<'a> { _ => None, }; - Ok(Some(GroupBy { exprs, having })) + Ok(Some(GroupBy { + exprs, + having: having.map(Box::new), + })) } fn parse_where(&mut self) -> Result>> { @@ -2124,7 +2018,7 @@ impl<'a> Parser<'a> { TK_WHERE => { eat_assert!(self, TK_WHERE); let expr = self.parse_expr(0)?; - Ok(Some(expr)) + Ok(Some(expr.into_boxed())) } _ => Ok(None), }, @@ -2186,7 +2080,7 @@ impl<'a> Parser<'a> { TK_ON => { eat_assert!(self, TK_ON); let expr = self.parse_expr(0)?; - Ok(Some(JoinConstraint::On(expr))) + Ok(Some(JoinConstraint::On(expr.into_boxed()))) } TK_USING => { eat_assert!(self, TK_USING); @@ -2414,7 +2308,7 @@ impl<'a> Parser<'a> { let expr = self.parse_expr(0)?; let alias = self.parse_as()?; - Ok(ResultColumn::Expr(expr, alias)) + Ok(ResultColumn::Expr(expr.into_boxed(), alias)) } } } @@ -2436,7 +2330,7 @@ impl<'a> Parser<'a> { } #[allow(clippy::vec_box)] - fn parse_nexpr_list(&mut self) -> Result>> { + fn parse_nexpr_list(&mut self) -> Result> { let mut result = vec![self.parse_expr(0)?]; while let Some(tok) = self.peek()? 
{ if tok.token_type == Some(TK_COMMA) { @@ -2534,7 +2428,7 @@ impl<'a> Parser<'a> { } fn parse_sorted_column(&mut self) -> Result { - let expr = self.parse_expr(0)?; + let expr = self.parse_expr(0)?.into_boxed(); let sort_order = self.parse_sort_order()?; let nulls = match self.peek()? { @@ -2598,7 +2492,7 @@ impl<'a> Parser<'a> { return Ok(None); } - let limit = self.parse_expr(0)?; + let limit = self.parse_expr(0)?.into_boxed(); let offset = match self.peek()? { Some(tok) => match tok.token_type.unwrap() { TK_OFFSET | TK_COMMA => { @@ -2608,7 +2502,8 @@ impl<'a> Parser<'a> { _ => None, }, _ => None, - }; + } + .map(Box::new); Ok(Some(Limit { expr: limit, @@ -2665,7 +2560,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_LP); let expr = self.parse_expr(0)?; eat_expect!(self, TK_RP); - Ok(TableConstraint::Check(expr)) + Ok(TableConstraint::Check(expr.into_boxed())) } fn parse_foreign_key_table_constraint(&mut self) -> Result { @@ -2952,7 +2847,7 @@ impl<'a> Parser<'a> { _ => None, }; - Ok(Stmt::Attach { expr, db_name, key }) + Ok(Stmt::attach(expr, db_name, key)) } fn parse_detach(&mut self) -> Result { @@ -2962,23 +2857,20 @@ impl<'a> Parser<'a> { } Ok(Stmt::Detach { - name: self.parse_expr(0)?, + name: self.parse_expr(0)?.into_boxed(), }) } fn parse_pragma_value(&mut self) -> Result { - match self.peek_no_eof()?.token_type.unwrap().fallback_id_if_ok() { + let expr = match self.peek_no_eof()?.token_type.unwrap().fallback_id_if_ok() { TK_ON | TK_DELETE | TK_DEFAULT => { let tok = eat_assert!(self, TK_ON, TK_DELETE, TK_DEFAULT); - Ok(Box::new(Expr::Literal(Literal::Keyword(from_bytes( - tok.value, - ))))) + Expr::Literal(Literal::Keyword(from_bytes(tok.value))) } - TK_ID | TK_STRING | TK_INDEXED | TK_JOIN_KW => { - Ok(Box::new(Expr::Name(self.parse_nm()?))) - } - _ => self.parse_signed(), - } + TK_ID | TK_STRING | TK_INDEXED | TK_JOIN_KW => Expr::Name(self.parse_nm()?), + _ => self.parse_signed()?, + }; + Ok(Box::new(expr)) } fn parse_pragma(&mut self) -> Result { @@ 
-3027,7 +2919,7 @@ impl<'a> Parser<'a> { _ => None, }; - Ok(Stmt::Vacuum { name, into }) + Ok(Stmt::vacuum(name, into)) } fn parse_term(&mut self) -> Result> { @@ -3041,7 +2933,7 @@ impl<'a> Parser<'a> { TK_CTIME_KW, ); - self.parse_expr_operand() + self.parse_expr_operand().map(Box::new) } fn parse_default_column_constraint(&mut self) -> Result { @@ -3172,7 +3064,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_LP); let expr = self.parse_expr(0)?; eat_expect!(self, TK_RP); - Ok(ColumnConstraint::Check(expr)) + Ok(ColumnConstraint::Check(expr.into_boxed())) } fn parse_ref_act(&mut self) -> Result { @@ -3293,7 +3185,7 @@ impl<'a> Parser<'a> { } eat_expect!(self, TK_LP); - let expr = self.parse_expr(0)?; + let expr = self.parse_expr(0)?.into_boxed(); eat_expect!(self, TK_RP); let typ = match self.peek()? { @@ -3529,7 +3421,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_EQ); Ok(Set { col_names: names, - expr: self.parse_expr(0)?, + expr: self.parse_expr(0)?.into_boxed(), }) } _ => { @@ -3537,7 +3429,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_EQ); Ok(Set { col_names: vec![name], - expr: self.parse_expr(0)?, + expr: self.parse_expr(0)?.into_boxed(), }) } } @@ -3766,7 +3658,8 @@ impl<'a> Parser<'a> { Some(self.parse_expr(0)?) 
} _ => None, - }; + } + .map(Box::new); eat_expect!(self, TK_BEGIN); @@ -4253,9 +4146,9 @@ mod tests { select: OneSelect::Select { distinctness: None, columns: vec![ResultColumn::Expr( - Box::new(Expr::Parenthesized(vec![Box::new(Expr::Literal( + Expr::Parenthesized(vec![Expr::Literal( Literal::Numeric("1".to_owned()), - ))])), + )]).into_boxed(), None, )], from: None, @@ -4732,8 +4625,8 @@ mod tests { Box::new(Expr::Case { base: None, when_then_pairs: vec![( - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), )], else_expr: Some(Box::new(Expr::Literal(Literal::Numeric( "3".to_owned(), @@ -4765,8 +4658,8 @@ mod tests { "4".to_owned(), )))), when_then_pairs: vec![( - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), )], else_expr: Some(Box::new(Expr::Literal(Literal::Numeric( "3".to_owned(), @@ -4798,8 +4691,8 @@ mod tests { "4".to_owned(), )))), when_then_pairs: vec![( - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), )], else_expr: None, }), @@ -5126,8 +5019,8 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { @@ -5164,17 +5057,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: 
Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: None, - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: None, })), @@ -5205,17 +5098,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: None, })), @@ -5246,17 +5139,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![ SortedColumn { expr: Box::new(Expr::Id(Name::Ident("test".to_owned()))), @@ -5293,17 
+5186,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Rows, @@ -5339,17 +5232,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Range, @@ -5385,17 +5278,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + 
partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5431,17 +5324,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5477,17 +5370,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5523,17 +5416,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail 
{ filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5569,17 +5462,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5615,17 +5508,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5663,17 +5556,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - 
Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5711,17 +5604,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5757,17 +5650,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5803,17 +5696,17 @@ mod tests { 
name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5849,17 +5742,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Box::new(Expr::Id(Name::Ident( + partition_by: vec![Expr::Id(Name::Ident( "product".to_owned(), - )))], + ))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -6323,9 +6216,9 @@ mod tests { lhs: Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), not: false, rhs: vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("3".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), + Expr::Literal(Literal::Numeric("3".to_owned())), ], }), Operator::And, @@ -6362,9 +6255,9 @@ mod tests { alias: None, }, args: vec![ - 
Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("3".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), + Expr::Literal(Literal::Numeric("3".to_owned())), ], }), Operator::And, @@ -7058,16 +6951,16 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], vec![ - Box::new(Expr::Literal(Literal::Numeric("3".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("4".to_owned()))), + Expr::Literal(Literal::Numeric("3".to_owned())), + Expr::Literal(Literal::Numeric("4".to_owned())), ], vec![ - Box::new(Expr::Literal(Literal::Numeric("5".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("6".to_owned()))), + Expr::Literal(Literal::Numeric("5".to_owned())), + Expr::Literal(Literal::Numeric("6".to_owned())), ], ]), compounds: vec![], @@ -7747,8 +7640,8 @@ mod tests { select: Box::new(SelectTable::TableCall( QualifiedName { db_name: None, name: Name::Ident("foo".to_owned()), alias: None }, vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], None, )), @@ -8480,8 +8373,8 @@ mod tests { table: Box::new(SelectTable::TableCall( QualifiedName { db_name: None, name: Name::Ident("bar".to_owned()), alias: None }, vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], None, )), @@ -8525,12 +8418,12 @@ mod tests 
{ body: SelectBody { select: OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))) + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())) ], vec![ - Box::new(Expr::Literal(Literal::Numeric("3".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("4".to_owned()))) + Expr::Literal(Literal::Numeric("3".to_owned())), + Expr::Literal(Literal::Numeric("4".to_owned())) ], ]), compounds: vec![], @@ -8654,11 +8547,11 @@ mod tests { where_clause: None, group_by: Some(GroupBy { exprs: vec![ - Box::new(Expr::Binary( + Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), Operator::Equals, Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - )), + ), ], having: None, }), @@ -8692,11 +8585,11 @@ mod tests { where_clause: None, group_by: Some(GroupBy { exprs: vec![ - Box::new(Expr::Binary( + Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), Operator::Equals, Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - )), + ), ], having: Some(Box::new(Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), @@ -8734,11 +8627,11 @@ mod tests { where_clause: None, group_by: Some(GroupBy { exprs: vec![ - Box::new(Expr::Binary( + Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), Operator::Equals, Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - )), + ), ], having: Some(Box::new(Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), @@ -8814,7 +8707,7 @@ mod tests { window: Window { base: None, partition_by: vec![ - Box::new(Expr::Id(Name::Ident("product".to_owned()))), + Expr::Id(Name::Ident("product".to_owned())), ], order_by: vec![], frame_clause: None, @@ -8852,7 +8745,7 @@ mod tests { window: Window { base: None, partition_by: vec![ - Box::new(Expr::Id(Name::Ident("product".to_owned()))), + 
Expr::Id(Name::Ident("product".to_owned())), ], order_by: vec![], frame_clause: None, @@ -8863,7 +8756,7 @@ mod tests { window: Window { base: None, partition_by: vec![ - Box::new(Expr::Id(Name::Ident("product_2".to_owned()))), + Expr::Id(Name::Ident("product_2".to_owned())), ], order_by: vec![], frame_clause: None, @@ -9117,7 +9010,7 @@ mod tests { name: None, constraint: ColumnConstraint::Default( Box::new(Expr::Parenthesized(vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), ])) ), }, @@ -10349,8 +10242,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], ]), compounds: vec![], @@ -10396,8 +10289,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], ]), compounds: vec![], @@ -10440,8 +10333,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], ]), compounds: vec![], @@ -10493,8 +10386,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], ]), compounds: vec![], @@ -10564,8 +10457,8 @@ mod tests { body: SelectBody { select: 
OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], ]), compounds: vec![], @@ -10632,8 +10525,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], ]), compounds: vec![], @@ -11140,8 +11033,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], ]), compounds: vec![], @@ -11168,8 +11061,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Expr::Literal(Literal::Numeric("1".to_owned())), + Expr::Literal(Literal::Numeric("2".to_owned())), ], ]), compounds: vec![], From 75c85e6284d06be35bd38c5b2d700b4596bc734a Mon Sep 17 00:00:00 2001 From: C4 Patino Date: Mon, 25 Aug 2025 16:40:55 -0500 Subject: [PATCH 55/73] ci: fix merge-pr issue to escape command-line backticks --- scripts/merge-pr.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/merge-pr.py b/scripts/merge-pr.py index 4ff2183a7..3aa48c683 100755 --- a/scripts/merge-pr.py +++ b/scripts/merge-pr.py @@ -9,6 +9,7 @@ import json import os import re +import shlex import subprocess import sys import tempfile @@ -112,8 +113,10 @@ def merge_remote(pr_number: int, commit_message: str, commit_title: str): try: print(f"\nMerging PR #{pr_number} with custom commit 
message...") + # Use gh pr merge with the commit message file - cmd = f'gh pr merge {pr_number} --merge --subject "{commit_title}" --body-file "{temp_file_path}"' + safe_title = shlex.quote(commit_title) + cmd = f'gh pr merge {pr_number} --merge --subject {safe_title} --body-file "{temp_file_path}"' output, error, returncode = run_command(cmd, capture_output=False) if returncode == 0: From 8c64b772e7150f0eb4019de89728b08c049946db Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 25 Aug 2025 19:04:14 -0400 Subject: [PATCH 56/73] Use previous WindowsIO impl as generic IO --- core/io/generic.rs | 73 ++++++++++++++++++++-------------------------- 1 file changed, 32 insertions(+), 41 deletions(-) diff --git a/core/io/generic.rs b/core/io/generic.rs index 9a87c6c63..83caa1405 100644 --- a/core/io/generic.rs +++ b/core/io/generic.rs @@ -1,24 +1,20 @@ -use super::MemoryIO; -use crate::{Clock, Completion, CompletionType, File, Instant, LimboError, OpenFlags, Result, IO}; -use std::cell::RefCell; +use crate::{Clock, Completion, File, Instant, LimboError, OpenFlags, Result, IO}; +use parking_lot::RwLock; use std::io::{Read, Seek, Write}; use std::sync::Arc; -use tracing::{debug, trace}; - +use tracing::{debug, instrument, trace, Level}; pub struct GenericIO {} impl GenericIO { pub fn new() -> Result { - debug!("Using IO backend 'generic'"); + debug!("Using IO backend 'syscall'"); Ok(Self {}) } } -unsafe impl Send for GenericIO {} -unsafe impl Sync for GenericIO {} - impl IO for GenericIO { - fn open_file(&self, path: &str, flags: OpenFlags, _direct: bool) -> Result> { + #[instrument(err, skip_all, level = Level::TRACE)] + fn open_file(&self, path: &str, flags: OpenFlags, direct: bool) -> Result> { trace!("open_file(path = {})", path); let mut file = std::fs::File::options(); file.read(true); @@ -30,17 +26,17 @@ impl IO for GenericIO { let file = file.open(path)?; Ok(Arc::new(GenericFile { - file: RefCell::new(file), - memory_io: Arc::new(MemoryIO::new()), + file: 
RwLock::new(file), })) } - + + #[instrument(err, skip_all, level = Level::TRACE)] fn remove_file(&self, path: &str) -> Result<()> { trace!("remove_file(path = {})", path); - std::fs::remove_file(path)?; - Ok(()) + Ok(std::fs::remove_file(path)?) } + #[instrument(err, skip_all, level = Level::TRACE)] fn run_once(&self) -> Result<()> { Ok(()) } @@ -57,68 +53,63 @@ impl Clock for GenericIO { } pub struct GenericFile { - file: RefCell, - memory_io: Arc, + file: RwLock, } -unsafe impl Send for GenericFile {} -unsafe impl Sync for GenericFile {} - impl File for GenericFile { - // Since we let the OS handle the locking, file locking is not supported on the generic IO implementation - // No-op implementation allows compilation but provides no actual file locking. - fn lock_file(&self, _exclusive: bool) -> Result<()> { - Ok(()) + #[instrument(err, skip_all, level = Level::TRACE)] + fn lock_file(&self, exclusive: bool) -> Result<()> { + unimplemented!() } + #[instrument(err, skip_all, level = Level::TRACE)] fn unlock_file(&self) -> Result<()> { - Ok(()) + unimplemented!() } + #[instrument(skip(self, c), level = Level::TRACE)] fn pread(&self, pos: usize, c: Completion) -> Result { - let mut file = self.file.borrow_mut(); + let mut file = self.file.write(); file.seek(std::io::SeekFrom::Start(pos as u64))?; - { + let nr = { let r = c.as_read(); - let mut buf = r.buf(); + let buf = r.buf(); let buf = buf.as_mut_slice(); file.read_exact(buf)?; - } - c.complete(0); + buf.len() as i32 + }; + c.complete(nr); Ok(c) } + #[instrument(skip(self, c, buffer), level = Level::TRACE)] fn pwrite(&self, pos: usize, buffer: Arc, c: Completion) -> Result { - let mut file = self.file.borrow_mut(); + let mut file = self.file.write(); file.seek(std::io::SeekFrom::Start(pos as u64))?; let buf = buffer.as_slice(); file.write_all(buf)?; - c.complete(buf.len() as i32); + c.complete(buffer.len() as i32); Ok(c) } + #[instrument(err, skip_all, level = Level::TRACE)] fn sync(&self, c: Completion) -> Result 
{ - let mut file = self.file.borrow_mut(); + let file = self.file.write(); file.sync_all()?; c.complete(0); Ok(c) } + #[instrument(err, skip_all, level = Level::TRACE)] fn truncate(&self, len: usize, c: Completion) -> Result { - let mut file = self.file.borrow_mut(); + let file = self.file.write(); file.set_len(len as u64)?; c.complete(0); Ok(c) } fn size(&self) -> Result { - let file = self.file.borrow(); + let file = self.file.read(); Ok(file.metadata().unwrap().len()) } } - -impl Drop for GenericFile { - fn drop(&mut self) { - self.unlock_file().expect("Failed to unlock file"); - } -} From 34da6611c19442b0090f325aef3052c9484364d2 Mon Sep 17 00:00:00 2001 From: Alex Miller Date: Mon, 25 Aug 2025 17:41:34 -0700 Subject: [PATCH 57/73] Update TPC-H running instructions in PERF.md Closes #2756 --- PERF.md | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/PERF.md b/PERF.md index 7ddf5eab0..ed09730cf 100644 --- a/PERF.md +++ b/PERF.md @@ -29,6 +29,7 @@ strace -f -c ../../Mobibench/shell/mobibench-turso -f 1024 -r 4 -a 0 -y 0 -t 1 - ./mobibench -p -n 1000 -d 0 -j 4 ``` + ## Clickbench We have a modified version of the Clickbench benchmark script that can be run with: @@ -41,7 +42,6 @@ This will build Turso in release mode, create a database, and run the benchmarks It will run the queries for both Turso and SQLite, and print the results. - ## Comparing VFS's/IO Back-ends (io_uring | syscall) ```shell @@ -54,26 +54,9 @@ The naive script will build and run limbo in release mode and execute the given ## TPC-H -1. Clone the Taratool TPC-H benchmarking tool: +Run the benchmark script: ```shell -git clone git@github.com:tarantool/tpch.git +./perf/tpc-h/benchmark.sh ``` -2. 
Patch the benchmark runner script: - -```patch -diff --git a/bench_queries.sh b/bench_queries.sh -index 6b894f9..c808e9a 100755 ---- a/bench_queries.sh -+++ b/bench_queries.sh -@@ -4,7 +4,7 @@ function check_q { - local query=queries/$*.sql - ( - echo $query -- time ( sqlite3 TPC-H.db < $query > /dev/null ) -+ time ( ../../limbo/target/release/limbo -m list TPC-H.db < $query > /dev/null ) - ) - } -``` - From b16f96b507baf612fdf8de1190d5961326a499e5 Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Mon, 25 Aug 2025 14:58:24 -0300 Subject: [PATCH 58/73] create sql_generation crate --- Cargo.lock | 4 ++++ Cargo.toml | 1 + sql_generation/Cargo.toml | 12 ++++++++++++ sql_generation/lib.rs | 1 + 4 files changed, 18 insertions(+) create mode 100644 sql_generation/Cargo.toml create mode 100644 sql_generation/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 81a4300c0..d7db34d08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3462,6 +3462,10 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d372029cb5195f9ab4e4b9aef550787dce78b124fcaee8d82519925defcd6f0d" +[[package]] +name = "sql_generation" +version = "0.1.4" + [[package]] name = "sqlparser_bench" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 646cd977c..f61b10ecc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ members = [ "parser", "sync/engine", "sync/javascript", + "sql_generation", ] exclude = ["perf/latency/limbo"] diff --git a/sql_generation/Cargo.toml b/sql_generation/Cargo.toml new file mode 100644 index 000000000..b4b5dbbcf --- /dev/null +++ b/sql_generation/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "sql_generation" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +path = "lib.rs" + +[dependencies] diff --git a/sql_generation/lib.rs b/sql_generation/lib.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/sql_generation/lib.rs @@ 
-0,0 +1 @@ + From 0285bdd72ce4a72f3ef4e061e5101ba0415f8eac Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Mon, 25 Aug 2025 15:14:10 -0300 Subject: [PATCH 59/73] copy generation code from simulator --- Cargo.lock | 25 +- Cargo.toml | 3 + sql_generation/Cargo.toml | 12 + sql_generation/generation/expr.rs | 296 ++++ sql_generation/generation/mod.rs | 166 ++ sql_generation/generation/plan.rs | 833 +++++++++ sql_generation/generation/predicate/binary.rs | 586 +++++++ sql_generation/generation/predicate/mod.rs | 378 ++++ sql_generation/generation/predicate/unary.rs | 306 ++++ sql_generation/generation/property.rs | 1533 +++++++++++++++++ sql_generation/generation/query.rs | 447 +++++ sql_generation/generation/table.rs | 258 +++ sql_generation/lib.rs | 3 +- sql_generation/model/mod.rs | 4 + sql_generation/model/query/create.rs | 45 + sql_generation/model/query/create_index.rs | 106 ++ sql_generation/model/query/delete.rs | 41 + sql_generation/model/query/drop.rs | 34 + sql_generation/model/query/insert.rs | 87 + sql_generation/model/query/mod.rs | 129 ++ sql_generation/model/query/predicate.rs | 146 ++ sql_generation/model/query/select.rs | 496 ++++++ sql_generation/model/query/transaction.rs | 60 + sql_generation/model/query/update.rs | 71 + sql_generation/model/table.rs | 428 +++++ 25 files changed, 6490 insertions(+), 3 deletions(-) create mode 100644 sql_generation/generation/expr.rs create mode 100644 sql_generation/generation/mod.rs create mode 100644 sql_generation/generation/plan.rs create mode 100644 sql_generation/generation/predicate/binary.rs create mode 100644 sql_generation/generation/predicate/mod.rs create mode 100644 sql_generation/generation/predicate/unary.rs create mode 100644 sql_generation/generation/property.rs create mode 100644 sql_generation/generation/query.rs create mode 100644 sql_generation/generation/table.rs create mode 100644 sql_generation/model/mod.rs create mode 100644 sql_generation/model/query/create.rs create mode 100644 
sql_generation/model/query/create_index.rs create mode 100644 sql_generation/model/query/delete.rs create mode 100644 sql_generation/model/query/drop.rs create mode 100644 sql_generation/model/query/insert.rs create mode 100644 sql_generation/model/query/mod.rs create mode 100644 sql_generation/model/query/predicate.rs create mode 100644 sql_generation/model/query/select.rs create mode 100644 sql_generation/model/query/transaction.rs create mode 100644 sql_generation/model/query/update.rs create mode 100644 sql_generation/model/table.rs diff --git a/Cargo.lock b/Cargo.lock index d7db34d08..1569f69ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -119,6 +119,15 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "anarchist-readable-name-generator-lib" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09a645c34bad5551ed4b2496536985efdc4373b097c0e57abf2eb14774538278" +dependencies = [ + "rand 0.9.2", +] + [[package]] name = "android-tzdata" version = "0.1.1" @@ -2125,7 +2134,7 @@ dependencies = [ name = "limbo_sim" version = "0.1.4" dependencies = [ - "anarchist-readable-name-generator-lib", + "anarchist-readable-name-generator-lib 0.1.2", "anyhow", "chrono", "clap", @@ -3465,6 +3474,18 @@ checksum = "d372029cb5195f9ab4e4b9aef550787dce78b124fcaee8d82519925defcd6f0d" [[package]] name = "sql_generation" version = "0.1.4" +dependencies = [ + "anarchist-readable-name-generator-lib 0.2.0", + "anyhow", + "hex", + "itertools 0.14.0", + "rand 0.9.2", + "rand_chacha 0.9.0", + "serde", + "tracing", + "turso_core", + "turso_parser", +] [[package]] name = "sqlparser_bench" @@ -4163,7 +4184,7 @@ dependencies = [ name = "turso_stress" version = "0.1.4" dependencies = [ - "anarchist-readable-name-generator-lib", + "anarchist-readable-name-generator-lib 0.1.2", "antithesis_sdk", "clap", "hex", diff --git a/Cargo.toml b/Cargo.toml index f61b10ecc..092d76d98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,9 @@ serde_json = "1.0" 
anyhow = "1.0.98" mimalloc = { version = "0.1.47", default-features = false } rusqlite = { version = "0.37.0", features = ["bundled"] } +itertools = "0.14.0" +rand = "0.9.2" +tracing = "0.1.41" [profile.release] debug = "line-tables-only" diff --git a/sql_generation/Cargo.toml b/sql_generation/Cargo.toml index b4b5dbbcf..d84d08380 100644 --- a/sql_generation/Cargo.toml +++ b/sql_generation/Cargo.toml @@ -10,3 +10,15 @@ repository.workspace = true path = "lib.rs" [dependencies] +hex = "0.4.3" +serde = { workspace = true, features = ["derive"] } +turso_core = { workspace = true, features = ["simulator"] } +turso_parser = { workspace = true, features = ["serde"] } +rand = { workspace = true } +anarchist-readable-name-generator-lib = "0.2.0" +itertools = { workspace = true } +anyhow = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +rand_chacha = "0.9.0" diff --git a/sql_generation/generation/expr.rs b/sql_generation/generation/expr.rs new file mode 100644 index 000000000..c5e33758c --- /dev/null +++ b/sql_generation/generation/expr.rs @@ -0,0 +1,296 @@ +use turso_parser::ast::{ + self, Expr, LikeOperator, Name, Operator, QualifiedName, Type, UnaryOperator, +}; + +use crate::{ + generation::{ + frequency, gen_random_text, one_of, pick, pick_index, Arbitrary, ArbitraryFrom, + ArbitrarySizedFrom, + }, + model::table::SimValue, +}; + +impl Arbitrary for Box +where + T: Arbitrary, +{ + fn arbitrary(rng: &mut R) -> Self { + Box::from(T::arbitrary(rng)) + } +} + +impl ArbitrarySizedFrom for Box +where + T: ArbitrarySizedFrom, +{ + fn arbitrary_sized_from(rng: &mut R, t: A, size: usize) -> Self { + Box::from(T::arbitrary_sized_from(rng, t, size)) + } +} + +impl Arbitrary for Option +where + T: Arbitrary, +{ + fn arbitrary(rng: &mut R) -> Self { + rng.random_bool(0.5).then_some(T::arbitrary(rng)) + } +} + +impl ArbitrarySizedFrom for Option +where + T: ArbitrarySizedFrom, +{ + fn arbitrary_sized_from(rng: &mut R, t: A, size: usize) -> Self { + 
rng.random_bool(0.5) + .then_some(T::arbitrary_sized_from(rng, t, size)) + } +} + +impl ArbitraryFrom for Vec +where + T: ArbitraryFrom, +{ + fn arbitrary_from(rng: &mut R, t: A) -> Self { + let size = rng.random_range(0..5); + (0..size).map(|_| T::arbitrary_from(rng, t)).collect() + } +} + +// Freestyling generation +impl ArbitrarySizedFrom<&SimulatorEnv> for Expr { + fn arbitrary_sized_from(rng: &mut R, t: &SimulatorEnv, size: usize) -> Self { + frequency( + vec![ + ( + 1, + Box::new(|rng| Expr::Literal(ast::Literal::arbitrary_from(rng, t))), + ), + ( + size, + Box::new(|rng| { + one_of( + vec![ + // Box::new(|rng: &mut R| Expr::Between { + // lhs: Box::arbitrary_sized_from(rng, t, size - 1), + // not: rng.gen_bool(0.5), + // start: Box::arbitrary_sized_from(rng, t, size - 1), + // end: Box::arbitrary_sized_from(rng, t, size - 1), + // }), + Box::new(|rng: &mut R| { + Expr::Binary( + Box::arbitrary_sized_from(rng, t, size - 1), + Operator::arbitrary(rng), + Box::arbitrary_sized_from(rng, t, size - 1), + ) + }), + // Box::new(|rng| Expr::Case { + // base: Option::arbitrary_from(rng, t), + // when_then_pairs: { + // let size = rng.gen_range(0..5); + // (0..size) + // .map(|_| (Self::arbitrary_from(rng, t), Self::arbitrary_from(rng, t))) + // .collect() + // }, + // else_expr: Option::arbitrary_from(rng, t), + // }), + // Box::new(|rng| Expr::Cast { + // expr: Box::arbitrary_sized_from(rng, t), + // type_name: Option::arbitrary(rng), + // }), + // Box::new(|rng| Expr::Collate(Box::arbitrary_sized_from(rng, t), CollateName::arbitrary(rng).0)), + // Box::new(|rng| Expr::InList { + // lhs: Box::arbitrary_sized_from(rng, t), + // not: rng.gen_bool(0.5), + // rhs: Option::arbitrary_from(rng, t), + // }), + // Box::new(|rng| Expr::IsNull(Box::arbitrary_sized_from(rng, t))), + // Box::new(|rng| { + // // let op = LikeOperator::arbitrary_from(rng, t); + // let op = ast::LikeOperator::Like; // todo: remove this line when LikeOperator is implemented + // let escape = if 
matches!(op, LikeOperator::Like) { + // Option::arbitrary_sized_from(rng, t, size - 1) + // } else { + // None + // }; + // Expr::Like { + // lhs: Box::arbitrary_sized_from(rng, t, size - 1), + // not: rng.gen_bool(0.5), + // op, + // rhs: Box::arbitrary_sized_from(rng, t, size - 1), + // escape, + // } + // }), + // Box::new(|rng| Expr::NotNull(Box::arbitrary_sized_from(rng, t))), + // // TODO: only supports one paranthesized expression + // Box::new(|rng| Expr::Parenthesized(vec![Expr::arbitrary_from(rng, t)])), + // Box::new(|rng| { + // let table_idx = pick_index(t.tables.len(), rng); + // let table = &t.tables[table_idx]; + // let col_idx = pick_index(table.columns.len(), rng); + // let col = &table.columns[col_idx]; + // Expr::Qualified(Name(table.name.clone()), Name(col.name.clone())) + // }) + Box::new(|rng| { + Expr::Unary( + UnaryOperator::arbitrary_from(rng, t), + Box::arbitrary_sized_from(rng, t, size - 1), + ) + }), + // TODO: skip Exists for now + // TODO: skip Function Call for now + // TODO: skip Function Call Star for now + // TODO: skip ID for now + // TODO: skip InSelect as still need to implement ArbitratyFrom for Select + // TODO: skip InTable + // TODO: skip Name + // TODO: Skip DoublyQualified for now + // TODO: skip Raise + // TODO: skip subquery + ], + rng, + ) + }), + ), + ], + rng, + ) + } +} + +impl Arbitrary for Operator { + fn arbitrary(rng: &mut R) -> Self { + let choices = [ + Operator::Add, + Operator::And, + // Operator::ArrowRight, -- todo: not implemented in `binary_compare` yet + // Operator::ArrowRightShift, -- todo: not implemented in `binary_compare` yet + Operator::BitwiseAnd, + // Operator::BitwiseNot, -- todo: not implemented in `binary_compare` yet + Operator::BitwiseOr, + // Operator::Concat, -- todo: not implemented in `exec_concat` + Operator::Divide, + Operator::Equals, + Operator::Greater, + Operator::GreaterEquals, + Operator::Is, + Operator::IsNot, + Operator::LeftShift, + Operator::Less, + Operator::LessEquals, + 
Operator::Modulus, + Operator::Multiply, + Operator::NotEquals, + Operator::Or, + Operator::RightShift, + Operator::Subtract, + ]; + *pick(&choices, rng) + } +} + +impl Arbitrary for Type { + fn arbitrary(rng: &mut R) -> Self { + let name = pick(&["INT", "INTEGER", "REAL", "TEXT", "BLOB", "ANY"], rng).to_string(); + Self { + name, + size: None, // TODO: come back later here + } + } +} + +struct CollateName(String); + +impl Arbitrary for CollateName { + fn arbitrary(rng: &mut R) -> Self { + let choice = rng.random_range(0..3); + CollateName( + match choice { + 0 => "BINARY", + 1 => "RTRIM", + 2 => "NOCASE", + _ => unreachable!(), + } + .to_string(), + ) + } +} + +impl ArbitraryFrom<&SimulatorEnv> for QualifiedName { + fn arbitrary_from(rng: &mut R, t: &SimulatorEnv) -> Self { + // TODO: for now just generate table name + let table_idx = pick_index(t.tables.len(), rng); + let table = &t.tables[table_idx]; + // TODO: for now forego alias + Self { + db_name: None, + name: Name::new(&table.name), + alias: None, + } + } +} + +impl ArbitraryFrom<&SimulatorEnv> for LikeOperator { + fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { + let choice = rng.random_range(0..4); + match choice { + 0 => LikeOperator::Glob, + 1 => LikeOperator::Like, + 2 => LikeOperator::Match, + 3 => LikeOperator::Regexp, + _ => unreachable!(), + } + } +} + +// Current implementation does not take into account the columns affinity nor if table is Strict +impl ArbitraryFrom<&SimulatorEnv> for ast::Literal { + fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { + loop { + let choice = rng.random_range(0..5); + let lit = match choice { + 0 => ast::Literal::Numeric({ + let integer = rng.random_bool(0.5); + if integer { + rng.random_range(i64::MIN..i64::MAX).to_string() + } else { + rng.random_range(-1e10..1e10).to_string() + } + }), + 1 => ast::Literal::String(format!("'{}'", gen_random_text(rng))), + 2 => ast::Literal::Blob(hex::encode(gen_random_text(rng).as_bytes())), + // TODO: skip 
Keyword + 3 => continue, + 4 => ast::Literal::Null, + // TODO: Ignore Date stuff for now + _ => continue, + }; + break lit; + } + } +} + +// Creates a litreal value +impl ArbitraryFrom<&Vec<&SimValue>> for ast::Expr { + fn arbitrary_from(rng: &mut R, values: &Vec<&SimValue>) -> Self { + if values.is_empty() { + return Self::Literal(ast::Literal::Null); + } + // TODO: for now just convert the value to an ast::Literal + let value = pick(values, rng); + Expr::Literal((*value).into()) + } +} + +impl ArbitraryFrom<&SimulatorEnv> for UnaryOperator { + fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { + let choice = rng.random_range(0..4); + match choice { + 0 => Self::BitwiseNot, + 1 => Self::Negative, + 2 => Self::Not, + 3 => Self::Positive, + _ => unreachable!(), + } + } +} diff --git a/sql_generation/generation/mod.rs b/sql_generation/generation/mod.rs new file mode 100644 index 000000000..44ae7f34d --- /dev/null +++ b/sql_generation/generation/mod.rs @@ -0,0 +1,166 @@ +use std::{iter::Sum, ops::SubAssign}; + +use anarchist_readable_name_generator_lib::readable_name_custom; +use rand::{distr::uniform::SampleUniform, Rng}; + +mod expr; +pub mod plan; +mod predicate; +pub mod property; +pub mod query; +pub mod table; + +type ArbitraryFromFunc<'a, R, T> = Box T + 'a>; +type Choice<'a, R, T> = (usize, Box Option + 'a>); + +/// Arbitrary trait for generating random values +/// An implementation of arbitrary is assumed to be a uniform sampling of +/// the possible values of the type, with a bias towards smaller values for +/// practicality. +pub trait Arbitrary { + fn arbitrary(rng: &mut R) -> Self; +} + +/// ArbitrarySized trait for generating random values of a specific size +/// An implementation of arbitrary_sized is assumed to be a uniform sampling of +/// the possible values of the type, with a bias towards smaller values for +/// practicality, but with the additional constraint that the generated value +/// must fit in the given size. 
This is useful for generating values that are +/// constrained by a specific size, such as integers or strings. +pub trait ArbitrarySized { + fn arbitrary_sized(rng: &mut R, size: usize) -> Self; +} + +/// ArbitraryFrom trait for generating random values from a given value +/// ArbitraryFrom allows for constructing relations, where the generated +/// value is dependent on the given value. These relations could be constraints +/// such as generating an integer within an interval, or a value that fits in a table, +/// or a predicate satisfying a given table row. +pub trait ArbitraryFrom { + fn arbitrary_from(rng: &mut R, t: T) -> Self; +} + +/// ArbitrarySizedFrom trait for generating random values from a given value +/// ArbitrarySizedFrom allows for constructing relations, where the generated +/// value is dependent on the given value and a size constraint. These relations +/// could be constraints such as generating an integer within an interval, +/// or a value that fits in a table, or a predicate satisfying a given table row, +/// but with the additional constraint that the generated value must fit in the given size. +/// This is useful for generating values that are constrained by a specific size, +/// such as integers or strings, while still being dependent on the given value. +pub trait ArbitrarySizedFrom { + fn arbitrary_sized_from(rng: &mut R, t: T, size: usize) -> Self; +} + +/// ArbitraryFromMaybe trait for fallibally generating random values from a given value +pub trait ArbitraryFromMaybe { + fn arbitrary_from_maybe(rng: &mut R, t: T) -> Option + where + Self: Sized; +} + +/// Frequency is a helper function for composing different generators with different frequency +/// of occurrences. +/// The type signature for the `N` parameter is a bit complex, but it +/// roughly corresponds to a type that can be summed, compared, subtracted and sampled, which are +/// the operations we require for the implementation. 
+// todo: switch to a simpler type signature that can accommodate all integer and float types, which +// should be enough for our purposes. +pub(crate) fn frequency< + T, + R: Rng, + N: Sum + PartialOrd + Copy + Default + SampleUniform + SubAssign, +>( + choices: Vec<(N, ArbitraryFromFunc)>, + rng: &mut R, +) -> T { + let total = choices.iter().map(|(weight, _)| *weight).sum::(); + let mut choice = rng.random_range(N::default()..total); + + for (weight, f) in choices { + if choice < weight { + return f(rng); + } + choice -= weight; + } + + unreachable!() +} + +/// one_of is a helper function for composing different generators with equal probability of occurrence. +pub(crate) fn one_of(choices: Vec>, rng: &mut R) -> T { + let index = rng.random_range(0..choices.len()); + choices[index](rng) +} + +/// backtrack is a helper function for composing different "failable" generators. +/// The function takes a list of functions that return an Option, along with number of retries +/// to make before giving up. 
+pub(crate) fn backtrack(mut choices: Vec>, rng: &mut R) -> Option { + loop { + // If there are no more choices left, we give up + let choices_ = choices + .iter() + .enumerate() + .filter(|(_, (retries, _))| *retries > 0) + .collect::>(); + if choices_.is_empty() { + tracing::trace!("backtrack: no more choices left"); + return None; + } + // Run a one_of on the remaining choices + let (choice_index, choice) = pick(&choices_, rng); + let choice_index = *choice_index; + // If the choice returns None, we decrement the number of retries and try again + let result = choice.1(rng); + if result.is_some() { + return result; + } else { + choices[choice_index].0 -= 1; + } + } +} + +/// pick is a helper function for uniformly picking a random element from a slice +pub(crate) fn pick<'a, T, R: Rng>(choices: &'a [T], rng: &mut R) -> &'a T { + let index = rng.random_range(0..choices.len()); + &choices[index] +} + +/// pick_index is typically used for picking an index from a slice to later refer to the element +/// at that index. +pub(crate) fn pick_index(choices: usize, rng: &mut R) -> usize { + rng.random_range(0..choices) +} + +/// pick_n_unique is a helper function for uniformly picking N unique elements from a range. +/// The elements themselves are usize, typically representing indices. +pub(crate) fn pick_n_unique( + range: std::ops::Range, + n: usize, + rng: &mut R, +) -> Vec { + use rand::seq::SliceRandom; + let mut items: Vec = range.collect(); + items.shuffle(rng); + items.into_iter().take(n).collect() +} + +/// gen_random_text uses `anarchist_readable_name_generator_lib` to generate random +/// readable names for tables, columns, text values etc. 
+pub(crate) fn gen_random_text(rng: &mut T) -> String { + let big_text = rng.random_ratio(1, 1000); + if big_text { + // let max_size: u64 = 2 * 1024 * 1024 * 1024; + let max_size: u64 = 2 * 1024; + let size = rng.random_range(1024..max_size); + let mut name = String::with_capacity(size as usize); + for i in 0..size { + name.push(((i % 26) as u8 + b'A') as char); + } + name + } else { + let name = readable_name_custom("_", rng); + name.replace("-", "_") + } +} diff --git a/sql_generation/generation/plan.rs b/sql_generation/generation/plan.rs new file mode 100644 index 000000000..eac9359b3 --- /dev/null +++ b/sql_generation/generation/plan.rs @@ -0,0 +1,833 @@ +use std::{ + collections::HashSet, + fmt::{Debug, Display}, + path::Path, + sync::Arc, + vec, +}; + +use serde::{Deserialize, Serialize}; + +use turso_core::{Connection, Result, StepResult}; + +use crate::{ + generation::query::SelectFree, + model::{ + query::{update::Update, Create, CreateIndex, Delete, Drop, Insert, Query, Select}, + table::SimValue, + }, + runner::{ + env::{SimConnection, SimulationType, SimulatorTables}, + io::SimulatorIO, + }, + SimulatorEnv, +}; + +use crate::generation::{frequency, Arbitrary, ArbitraryFrom}; + +use super::property::{remaining, Property}; + +pub(crate) type ResultSet = Result>>; + +#[derive(Clone, Serialize, Deserialize)] +pub(crate) struct InteractionPlan { + pub(crate) plan: Vec, +} + +impl InteractionPlan { + /// Compute via diff computes a a plan from a given `.plan` file without the need to parse + /// sql. This is possible because there are two versions of the plan file, one that is human + /// readable and one that is serialized as JSON. Under watch mode, the users will be able to + /// delete interactions from the human readable file, and this function uses the JSON file as + /// a baseline to detect with interactions were deleted and constructs the plan from the + /// remaining interactions. 
+ pub(crate) fn compute_via_diff(plan_path: &Path) -> Vec> { + let interactions = std::fs::read_to_string(plan_path).unwrap(); + let interactions = interactions.lines().collect::>(); + + let plan: InteractionPlan = serde_json::from_str( + std::fs::read_to_string(plan_path.with_extension("json")) + .unwrap() + .as_str(), + ) + .unwrap(); + + let mut plan = plan + .plan + .into_iter() + .map(|i| i.interactions()) + .collect::>(); + + let (mut i, mut j) = (0, 0); + + while i < interactions.len() && j < plan.len() { + if interactions[i].starts_with("-- begin") + || interactions[i].starts_with("-- end") + || interactions[i].is_empty() + { + i += 1; + continue; + } + + // interactions[i] is the i'th line in the human readable plan + // plan[j][k] is the k'th interaction in the j'th property + let mut k = 0; + + while k < plan[j].len() { + if i >= interactions.len() { + let _ = plan.split_off(j + 1); + let _ = plan[j].split_off(k); + break; + } + tracing::error!("Comparing '{}' with '{}'", interactions[i], plan[j][k]); + if interactions[i].contains(plan[j][k].to_string().as_str()) { + i += 1; + k += 1; + } else { + plan[j].remove(k); + panic!("Comparing '{}' with '{}'", interactions[i], plan[j][k]); + } + } + + if plan[j].is_empty() { + plan.remove(j); + } else { + j += 1; + } + } + let _ = plan.split_off(j); + plan + } +} + +pub(crate) struct InteractionPlanState { + pub(crate) stack: Vec, + pub(crate) interaction_pointer: usize, + pub(crate) secondary_pointer: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) enum Interactions { + Property(Property), + Query(Query), + Fault(Fault), +} + +impl Interactions { + pub(crate) fn name(&self) -> Option<&str> { + match self { + Interactions::Property(property) => Some(property.name()), + Interactions::Query(_) => None, + Interactions::Fault(_) => None, + } + } + + pub(crate) fn interactions(&self) -> Vec { + match self { + Interactions::Property(property) => property.interactions(), + 
Interactions::Query(query) => vec![Interaction::Query(query.clone())], + Interactions::Fault(fault) => vec![Interaction::Fault(fault.clone())], + } + } +} + +impl Interactions { + pub(crate) fn dependencies(&self) -> HashSet { + match self { + Interactions::Property(property) => { + property + .interactions() + .iter() + .fold(HashSet::new(), |mut acc, i| match i { + Interaction::Query(q) => { + acc.extend(q.dependencies()); + acc + } + _ => acc, + }) + } + Interactions::Query(query) => query.dependencies(), + Interactions::Fault(_) => HashSet::new(), + } + } + + pub(crate) fn uses(&self) -> Vec { + match self { + Interactions::Property(property) => { + property + .interactions() + .iter() + .fold(vec![], |mut acc, i| match i { + Interaction::Query(q) => { + acc.extend(q.uses()); + acc + } + _ => acc, + }) + } + Interactions::Query(query) => query.uses(), + Interactions::Fault(_) => vec![], + } + } +} + +impl Display for InteractionPlan { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for interactions in &self.plan { + match interactions { + Interactions::Property(property) => { + let name = property.name(); + writeln!(f, "-- begin testing '{name}'")?; + for interaction in property.interactions() { + write!(f, "\t")?; + + match interaction { + Interaction::Query(query) => writeln!(f, "{query};")?, + Interaction::Assumption(assumption) => { + writeln!(f, "-- ASSUME {};", assumption.name)? + } + Interaction::Assertion(assertion) => { + writeln!(f, "-- ASSERT {};", assertion.name)? + } + Interaction::Fault(fault) => writeln!(f, "-- FAULT '{fault}';")?, + Interaction::FsyncQuery(query) => { + writeln!(f, "-- FSYNC QUERY;")?; + writeln!(f, "{query};")?; + writeln!(f, "{query};")? + } + Interaction::FaultyQuery(query) => { + writeln!(f, "{query}; -- FAULTY QUERY")? 
+ } + } + } + writeln!(f, "-- end testing '{name}'")?; + } + Interactions::Fault(fault) => { + writeln!(f, "-- FAULT '{fault}'")?; + } + Interactions::Query(query) => { + writeln!(f, "{query};")?; + } + } + } + + Ok(()) + } +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct InteractionStats { + pub(crate) read_count: usize, + pub(crate) write_count: usize, + pub(crate) delete_count: usize, + pub(crate) update_count: usize, + pub(crate) create_count: usize, + pub(crate) create_index_count: usize, + pub(crate) drop_count: usize, + pub(crate) begin_count: usize, + pub(crate) commit_count: usize, + pub(crate) rollback_count: usize, +} + +impl Display for InteractionStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Read: {}, Write: {}, Delete: {}, Update: {}, Create: {}, CreateIndex: {}, Drop: {}, Begin: {}, Commit: {}, Rollback: {}", + self.read_count, + self.write_count, + self.delete_count, + self.update_count, + self.create_count, + self.create_index_count, + self.drop_count, + self.begin_count, + self.commit_count, + self.rollback_count, + ) + } +} + +#[derive(Debug)] +pub(crate) enum Interaction { + Query(Query), + Assumption(Assertion), + Assertion(Assertion), + Fault(Fault), + /// Will attempt to run any random query. 
However, when the connection tries to sync it will + /// close all connections and reopen the database and assert that no data was lost + FsyncQuery(Query), + FaultyQuery(Query), +} + +impl Display for Interaction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Query(query) => write!(f, "{query}"), + Self::Assumption(assumption) => write!(f, "ASSUME {}", assumption.name), + Self::Assertion(assertion) => write!(f, "ASSERT {}", assertion.name), + Self::Fault(fault) => write!(f, "FAULT '{fault}'"), + Self::FsyncQuery(query) => write!(f, "{query}"), + Self::FaultyQuery(query) => write!(f, "{query}; -- FAULTY QUERY"), + } + } +} + +type AssertionFunc = dyn Fn(&Vec, &mut SimulatorEnv) -> Result>; + +enum AssertionAST { + Pick(), +} + +pub(crate) struct Assertion { + pub(crate) func: Box, + pub(crate) name: String, // For display purposes in the plan +} + +impl Debug for Assertion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Assertion") + .field("name", &self.name) + .finish() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) enum Fault { + Disconnect, + ReopenDatabase, +} + +impl Display for Fault { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Fault::Disconnect => write!(f, "DISCONNECT"), + Fault::ReopenDatabase => write!(f, "REOPEN_DATABASE"), + } + } +} + +impl InteractionPlan { + pub(crate) fn new() -> Self { + Self { plan: Vec::new() } + } + + pub(crate) fn stats(&self) -> InteractionStats { + let mut stats = InteractionStats { + read_count: 0, + write_count: 0, + delete_count: 0, + update_count: 0, + create_count: 0, + create_index_count: 0, + drop_count: 0, + begin_count: 0, + commit_count: 0, + rollback_count: 0, + }; + + fn query_stat(q: &Query, stats: &mut InteractionStats) { + match q { + Query::Select(_) => stats.read_count += 1, + Query::Insert(_) => stats.write_count += 1, + Query::Delete(_) => 
stats.delete_count += 1, + Query::Create(_) => stats.create_count += 1, + Query::Drop(_) => stats.drop_count += 1, + Query::Update(_) => stats.update_count += 1, + Query::CreateIndex(_) => stats.create_index_count += 1, + Query::Begin(_) => stats.begin_count += 1, + Query::Commit(_) => stats.commit_count += 1, + Query::Rollback(_) => stats.rollback_count += 1, + } + } + for interactions in &self.plan { + match interactions { + Interactions::Property(property) => { + for interaction in &property.interactions() { + if let Interaction::Query(query) = interaction { + query_stat(query, &mut stats); + } + } + } + Interactions::Query(query) => { + query_stat(query, &mut stats); + } + Interactions::Fault(_) => {} + } + } + + stats + } +} + +impl ArbitraryFrom<&mut SimulatorEnv> for InteractionPlan { + fn arbitrary_from(rng: &mut R, env: &mut SimulatorEnv) -> Self { + let mut plan = InteractionPlan::new(); + + let num_interactions = env.opts.max_interactions; + + // First create at least one table + let create_query = Create::arbitrary(rng); + env.tables.push(create_query.table.clone()); + + plan.plan + .push(Interactions::Query(Query::Create(create_query))); + + while plan.plan.len() < num_interactions { + tracing::debug!( + "Generating interaction {}/{}", + plan.plan.len(), + num_interactions + ); + let interactions = Interactions::arbitrary_from(rng, (env, plan.stats())); + interactions.shadow(&mut env.tables); + plan.plan.push(interactions); + } + + tracing::info!("Generated plan with {} interactions", plan.plan.len()); + plan + } +} + +impl Interaction { + pub(crate) fn execute_query(&self, conn: &mut Arc, _io: &SimulatorIO) -> ResultSet { + if let Self::Query(query) = self { + let query_str = query.to_string(); + let rows = conn.query(&query_str); + if rows.is_err() { + let err = rows.err(); + tracing::debug!( + "Error running query '{}': {:?}", + &query_str[0..query_str.len().min(4096)], + err + ); + if let Some(turso_core::LimboError::ParseError(e)) = err { + 
panic!("Unexpected parse error: {e}"); + } + return Err(err.unwrap()); + } + let rows = rows?; + assert!(rows.is_some()); + let mut rows = rows.unwrap(); + let mut out = Vec::new(); + while let Ok(row) = rows.step() { + match row { + StepResult::Row => { + let row = rows.row().unwrap(); + let mut r = Vec::new(); + for v in row.get_values() { + let v = v.into(); + r.push(v); + } + out.push(r); + } + StepResult::IO => { + rows.run_once().unwrap(); + } + StepResult::Interrupt => {} + StepResult::Done => { + break; + } + StepResult::Busy => { + return Err(turso_core::LimboError::Busy); + } + } + } + + Ok(out) + } else { + unreachable!("unexpected: this function should only be called on queries") + } + } + + pub(crate) fn execute_assertion( + &self, + stack: &Vec, + env: &mut SimulatorEnv, + ) -> Result<()> { + match self { + Self::Assertion(assertion) => { + let result = assertion.func.as_ref()(stack, env); + match result { + Ok(Ok(())) => Ok(()), + Ok(Err(message)) => Err(turso_core::LimboError::InternalError(format!( + "Assertion '{}' failed: {}", + assertion.name, message + ))), + Err(err) => Err(turso_core::LimboError::InternalError(format!( + "Assertion '{}' execution error: {}", + assertion.name, err + ))), + } + } + _ => { + unreachable!("unexpected: this function should only be called on assertions") + } + } + } + + pub(crate) fn execute_assumption( + &self, + stack: &Vec, + env: &mut SimulatorEnv, + ) -> Result<()> { + match self { + Self::Assumption(assumption) => { + let result = assumption.func.as_ref()(stack, env); + match result { + Ok(Ok(())) => Ok(()), + Ok(Err(message)) => Err(turso_core::LimboError::InternalError(format!( + "Assumption '{}' failed: {}", + assumption.name, message + ))), + Err(err) => Err(turso_core::LimboError::InternalError(format!( + "Assumption '{}' execution error: {}", + assumption.name, err + ))), + } + } + _ => { + unreachable!("unexpected: this function should only be called on assumptions") + } + } + } + + pub(crate) fn 
execute_fault(&self, env: &mut SimulatorEnv, conn_index: usize) -> Result<()> { + match self { + Self::Fault(fault) => { + match fault { + Fault::Disconnect => { + if env.connections[conn_index].is_connected() { + env.connections[conn_index].disconnect(); + } else { + return Err(turso_core::LimboError::InternalError( + "connection already disconnected".into(), + )); + } + env.connections[conn_index] = SimConnection::Disconnected; + } + Fault::ReopenDatabase => { + reopen_database(env); + } + } + Ok(()) + } + _ => { + unreachable!("unexpected: this function should only be called on faults") + } + } + } + + pub(crate) fn execute_fsync_query( + &self, + conn: Arc, + env: &mut SimulatorEnv, + ) -> ResultSet { + if let Self::FsyncQuery(query) = self { + let query_str = query.to_string(); + let rows = conn.query(&query_str); + if rows.is_err() { + let err = rows.err(); + tracing::debug!( + "Error running query '{}': {:?}", + &query_str[0..query_str.len().min(4096)], + err + ); + return Err(err.unwrap()); + } + let mut rows = rows.unwrap().unwrap(); + let mut out = Vec::new(); + while let Ok(row) = rows.step() { + match row { + StepResult::Row => { + let row = rows.row().unwrap(); + let mut r = Vec::new(); + for v in row.get_values() { + let v = v.into(); + r.push(v); + } + out.push(r); + } + StepResult::IO => { + let syncing = { + let files = env.io.files.borrow(); + // TODO: currently assuming we only have 1 file that is syncing + files + .iter() + .any(|file| file.sync_completion.borrow().is_some()) + }; + if syncing { + reopen_database(env); + } else { + rows.run_once().unwrap(); + } + } + StepResult::Done => { + break; + } + StepResult::Busy => { + return Err(turso_core::LimboError::Busy); + } + StepResult::Interrupt => {} + } + } + + Ok(out) + } else { + unreachable!("unexpected: this function should only be called on queries") + } + } + + pub(crate) fn execute_faulty_query( + &self, + conn: &Arc, + env: &mut SimulatorEnv, + ) -> ResultSet { + use rand::Rng; + if 
let Self::FaultyQuery(query) = self { + let query_str = query.to_string(); + let rows = conn.query(&query_str); + if rows.is_err() { + let err = rows.err(); + tracing::debug!( + "Error running query '{}': {:?}", + &query_str[0..query_str.len().min(4096)], + err + ); + if let Some(turso_core::LimboError::ParseError(e)) = err { + panic!("Unexpected parse error: {e}"); + } + return Err(err.unwrap()); + } + let mut rows = rows.unwrap().unwrap(); + let mut out = Vec::new(); + let mut current_prob = 0.05; + let mut incr = 0.001; + loop { + let syncing = { + let files = env.io.files.borrow(); + files + .iter() + .any(|file| file.sync_completion.borrow().is_some()) + }; + let inject_fault = env.rng.gen_bool(current_prob); + // TODO: avoid for now injecting faults when syncing + if inject_fault && !syncing { + env.io.inject_fault(true); + } + + match rows.step()? { + StepResult::Row => { + let row = rows.row().unwrap(); + let mut r = Vec::new(); + for v in row.get_values() { + let v = v.into(); + r.push(v); + } + out.push(r); + } + StepResult::IO => { + rows.run_once()?; + current_prob += incr; + if current_prob > 1.0 { + current_prob = 1.0; + } else { + incr *= 1.01; + } + } + StepResult::Done => { + break; + } + StepResult::Busy => { + return Err(turso_core::LimboError::Busy); + } + StepResult::Interrupt => {} + } + } + + Ok(out) + } else { + unreachable!("unexpected: this function should only be called on queries") + } + } +} + +fn reopen_database(env: &mut SimulatorEnv) { + // 1. Close all connections without default checkpoint-on-close behavior + // to expose bugs related to how we handle WAL + let num_conns = env.connections.len(); + env.connections.clear(); + + // Clear all open files + // TODO: for correct reporting of faults we should get all the recorded numbers and transfer to the new file + env.io.files.borrow_mut().clear(); + + // 2. 
Re-open database + match env.type_ { + SimulationType::Differential => { + for _ in 0..num_conns { + env.connections.push(SimConnection::SQLiteConnection( + rusqlite::Connection::open(env.get_db_path()) + .expect("Failed to open SQLite connection"), + )); + } + } + SimulationType::Default | SimulationType::Doublecheck => { + env.db = None; + let db = match turso_core::Database::open_file( + env.io.clone(), + env.get_db_path().to_str().expect("path should be 'to_str'"), + false, + true, + ) { + Ok(db) => db, + Err(e) => { + tracing::error!( + "Failed to open database at {}: {}", + env.get_db_path().display(), + e + ); + panic!("Failed to open database: {e}"); + } + }; + + env.db = Some(db); + + for _ in 0..num_conns { + env.connections.push(SimConnection::LimboConnection( + env.db.as_ref().expect("db to be Some").connect().unwrap(), + )); + } + } + }; +} + +fn random_create(rng: &mut R, env: &SimulatorEnv) -> Interactions { + let mut create = Create::arbitrary(rng); + while env.tables.iter().any(|t| t.name == create.table.name) { + create = Create::arbitrary(rng); + } + Interactions::Query(Query::Create(create)) +} + +fn random_read(rng: &mut R, env: &SimulatorEnv) -> Interactions { + Interactions::Query(Query::Select(Select::arbitrary_from(rng, env))) +} + +fn random_expr(rng: &mut R, env: &SimulatorEnv) -> Interactions { + Interactions::Query(Query::Select(SelectFree::arbitrary_from(rng, env).0)) +} + +fn random_write(rng: &mut R, env: &SimulatorEnv) -> Interactions { + Interactions::Query(Query::Insert(Insert::arbitrary_from(rng, env))) +} + +fn random_delete(rng: &mut R, env: &SimulatorEnv) -> Interactions { + Interactions::Query(Query::Delete(Delete::arbitrary_from(rng, env))) +} + +fn random_update(rng: &mut R, env: &SimulatorEnv) -> Interactions { + Interactions::Query(Query::Update(Update::arbitrary_from(rng, env))) +} + +fn random_drop(rng: &mut R, env: &SimulatorEnv) -> Interactions { + Interactions::Query(Query::Drop(Drop::arbitrary_from(rng, env))) +} + 
+fn random_create_index(rng: &mut R, env: &SimulatorEnv) -> Option { + if env.tables.is_empty() { + return None; + } + let mut create_index = CreateIndex::arbitrary_from(rng, env); + while env + .tables + .iter() + .find(|t| t.name == create_index.table_name) + .expect("table should exist") + .indexes + .iter() + .any(|i| i == &create_index.index_name) + { + create_index = CreateIndex::arbitrary_from(rng, env); + } + + Some(Interactions::Query(Query::CreateIndex(create_index))) +} + +fn random_fault(rng: &mut R, env: &SimulatorEnv) -> Interactions { + let faults = if env.opts.disable_reopen_database { + vec![Fault::Disconnect] + } else { + vec![Fault::Disconnect, Fault::ReopenDatabase] + }; + let fault = faults[rng.random_range(0..faults.len())].clone(); + Interactions::Fault(fault) +} + +impl ArbitraryFrom<(&SimulatorEnv, InteractionStats)> for Interactions { + fn arbitrary_from( + rng: &mut R, + (env, stats): (&SimulatorEnv, InteractionStats), + ) -> Self { + let remaining_ = remaining(env, &stats); + frequency( + vec![ + ( + f64::min(remaining_.read, remaining_.write) + remaining_.create, + Box::new(|rng: &mut R| { + Interactions::Property(Property::arbitrary_from(rng, (env, &stats))) + }), + ), + ( + remaining_.read, + Box::new(|rng: &mut R| random_read(rng, env)), + ), + ( + remaining_.read / 3.0, + Box::new(|rng: &mut R| random_expr(rng, env)), + ), + ( + remaining_.write, + Box::new(|rng: &mut R| random_write(rng, env)), + ), + ( + remaining_.create, + Box::new(|rng: &mut R| random_create(rng, env)), + ), + ( + remaining_.create_index, + Box::new(|rng: &mut R| { + if let Some(interaction) = random_create_index(rng, env) { + interaction + } else { + // if no tables exist, we can't create an index, so fallback to creating a table + random_create(rng, env) + } + }), + ), + ( + remaining_.delete, + Box::new(|rng: &mut R| random_delete(rng, env)), + ), + ( + remaining_.update, + Box::new(|rng: &mut R| random_update(rng, env)), + ), + ( + // remaining_.drop, + 
0.0, + Box::new(|rng: &mut R| random_drop(rng, env)), + ), + ( + remaining_ + .read + .min(remaining_.write) + .min(remaining_.create) + .max(1.0), + Box::new(|rng: &mut R| random_fault(rng, env)), + ), + ], + rng, + ) + } +} diff --git a/sql_generation/generation/predicate/binary.rs b/sql_generation/generation/predicate/binary.rs new file mode 100644 index 000000000..29c1727a9 --- /dev/null +++ b/sql_generation/generation/predicate/binary.rs @@ -0,0 +1,586 @@ +//! Contains code for generation for [ast::Expr::Binary] Predicate + +use turso_parser::ast::{self, Expr}; + +use crate::{ + generation::{ + backtrack, one_of, pick, + predicate::{CompoundPredicate, SimplePredicate}, + table::{GTValue, LTValue, LikeValue}, + ArbitraryFrom, ArbitraryFromMaybe as _, + }, + model::{ + query::predicate::Predicate, + table::{SimValue, Table, TableContext}, + }, +}; + +impl Predicate { + /// Generate an [ast::Expr::Binary] [Predicate] from a column and [SimValue] + pub fn from_column_binary( + rng: &mut R, + column_name: &str, + value: &SimValue, + ) -> Predicate { + let expr = one_of( + vec![ + Box::new(|_| { + Expr::Binary( + Box::new(Expr::Id(ast::Name::Ident(column_name.to_string()))), + ast::Operator::Equals, + Box::new(Expr::Literal(value.into())), + ) + }), + Box::new(|rng| { + let gt_value = GTValue::arbitrary_from(rng, value).0; + Expr::Binary( + Box::new(Expr::Id(ast::Name::Ident(column_name.to_string()))), + ast::Operator::Greater, + Box::new(Expr::Literal(gt_value.into())), + ) + }), + Box::new(|rng| { + let lt_value = LTValue::arbitrary_from(rng, value).0; + Expr::Binary( + Box::new(Expr::Id(ast::Name::Ident(column_name.to_string()))), + ast::Operator::Less, + Box::new(Expr::Literal(lt_value.into())), + ) + }), + ], + rng, + ); + Predicate(expr) + } + + /// Produces a true [ast::Expr::Binary] [Predicate] that is true for the provided row in the given table + pub fn true_binary(rng: &mut R, t: &Table, row: &[SimValue]) -> Predicate { + // Pick a column + let 
column_index = rng.random_range(0..t.columns.len()); + let mut column = t.columns[column_index].clone(); + let value = &row[column_index]; + + let mut table_name = t.name.clone(); + if t.name.is_empty() { + // If the table name is empty, we cannot create a qualified expression + // so we use the column name directly + let mut splitted = column.name.split('.'); + table_name = splitted + .next() + .expect("Column name should have a table prefix for a joined table") + .to_string(); + column.name = splitted + .next() + .expect("Column name should have a column suffix for a joined table") + .to_string(); + } + + let expr = backtrack( + vec![ + ( + 1, + Box::new(|_| { + Some(Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(&table_name), + ast::Name::new(&column.name), + )), + ast::Operator::Equals, + Box::new(Expr::Literal(value.into())), + )) + }), + ), + ( + 1, + Box::new(|rng| { + let v = SimValue::arbitrary_from(rng, &column.column_type); + if &v == value { + None + } else { + Some(Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(&table_name), + ast::Name::new(&column.name), + )), + ast::Operator::NotEquals, + Box::new(Expr::Literal(v.into())), + )) + } + }), + ), + ( + 1, + Box::new(|rng| { + let lt_value = LTValue::arbitrary_from(rng, value).0; + Some(Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(&table_name), + ast::Name::new(&column.name), + )), + ast::Operator::Greater, + Box::new(Expr::Literal(lt_value.into())), + )) + }), + ), + ( + 1, + Box::new(|rng| { + let gt_value = GTValue::arbitrary_from(rng, value).0; + Some(Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(&table_name), + ast::Name::new(&column.name), + )), + ast::Operator::Less, + Box::new(Expr::Literal(gt_value.into())), + )) + }), + ), + ( + 1, + Box::new(|rng| { + // TODO: generation for Like and Glob expressions should be extracted to different module + LikeValue::arbitrary_from_maybe(rng, value).map(|like| { + Expr::Like { + lhs: 
Box::new(ast::Expr::Qualified( + ast::Name::new(&table_name), + ast::Name::new(&column.name), + )), + not: false, // TODO: also generate this value eventually + op: ast::LikeOperator::Like, + rhs: Box::new(Expr::Literal(like.0.into())), + escape: None, // TODO: implement + } + }) + }), + ), + ], + rng, + ); + // Backtrack will always return Some here + Predicate(expr.unwrap()) + } + + /// Produces an [ast::Expr::Binary] [Predicate] that is false for the provided row in the given table + pub fn false_binary(rng: &mut R, t: &Table, row: &[SimValue]) -> Predicate { + // Pick a column + let column_index = rng.random_range(0..t.columns.len()); + let mut column = t.columns[column_index].clone(); + let mut table_name = t.name.clone(); + let value = &row[column_index]; + + if t.name.is_empty() { + // If the table name is empty, we cannot create a qualified expression + // so we use the column name directly + let mut splitted = column.name.split('.'); + table_name = splitted + .next() + .expect("Column name should have a table prefix for a joined table") + .to_string(); + column.name = splitted + .next() + .expect("Column name should have a column suffix for a joined table") + .to_string(); + } + + let expr = one_of( + vec![ + Box::new(|_| { + Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(&table_name), + ast::Name::new(&column.name), + )), + ast::Operator::NotEquals, + Box::new(Expr::Literal(value.into())), + ) + }), + Box::new(|rng| { + let v = loop { + let v = SimValue::arbitrary_from(rng, &column.column_type); + if &v != value { + break v; + } + }; + Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(&table_name), + ast::Name::new(&column.name), + )), + ast::Operator::Equals, + Box::new(Expr::Literal(v.into())), + ) + }), + Box::new(|rng| { + let gt_value = GTValue::arbitrary_from(rng, value).0; + Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(&table_name), + ast::Name::new(&column.name), + )), + ast::Operator::Greater, + 
Box::new(Expr::Literal(gt_value.into())), + ) + }), + Box::new(|rng| { + let lt_value = LTValue::arbitrary_from(rng, value).0; + Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(&table_name), + ast::Name::new(&column.name), + )), + ast::Operator::Less, + Box::new(Expr::Literal(lt_value.into())), + ) + }), + ], + rng, + ); + Predicate(expr) + } +} + +impl SimplePredicate { + /// Generates a true [ast::Expr::Binary] [SimplePredicate] from a [TableContext] for a row in the table + pub fn true_binary( + rng: &mut R, + table: &T, + row: &[SimValue], + ) -> Self { + // Pick a random column + let columns = table.columns().collect::>(); + let column_index = rng.random_range(0..columns.len()); + let column = columns[column_index]; + let column_value = &row[column_index]; + let table_name = column.table_name; + // Avoid creation of NULLs + if row.is_empty() { + return SimplePredicate(Predicate(Expr::Literal(SimValue::TRUE.into()))); + } + + let expr = one_of( + vec![ + Box::new(|_rng| { + Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(table_name), + ast::Name::new(&column.column.name), + )), + ast::Operator::Equals, + Box::new(Expr::Literal(column_value.into())), + ) + }), + Box::new(|rng| { + let lt_value = LTValue::arbitrary_from(rng, column_value).0; + Expr::Binary( + Box::new(Expr::Qualified( + ast::Name::new(table_name), + ast::Name::new(&column.column.name), + )), + ast::Operator::Greater, + Box::new(Expr::Literal(lt_value.into())), + ) + }), + Box::new(|rng| { + let gt_value = GTValue::arbitrary_from(rng, column_value).0; + Expr::Binary( + Box::new(Expr::Qualified( + ast::Name::new(table_name), + ast::Name::new(&column.column.name), + )), + ast::Operator::Less, + Box::new(Expr::Literal(gt_value.into())), + ) + }), + ], + rng, + ); + SimplePredicate(Predicate(expr)) + } + + /// Generates a false [ast::Expr::Binary] [SimplePredicate] from a [TableContext] for a row in the table + pub fn false_binary( + rng: &mut R, + table: &T, + row: 
&[SimValue], + ) -> Self { + let columns = table.columns().collect::>(); + // Pick a random column + let column_index = rng.random_range(0..columns.len()); + let column = columns[column_index]; + let column_value = &row[column_index]; + let table_name = column.table_name; + // Avoid creation of NULLs + if row.is_empty() { + return SimplePredicate(Predicate(Expr::Literal(SimValue::FALSE.into()))); + } + + let expr = one_of( + vec![ + Box::new(|_rng| { + Expr::Binary( + Box::new(Expr::Qualified( + ast::Name::new(table_name), + ast::Name::new(&column.column.name), + )), + ast::Operator::NotEquals, + Box::new(Expr::Literal(column_value.into())), + ) + }), + Box::new(|rng| { + let gt_value = GTValue::arbitrary_from(rng, column_value).0; + Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(table_name), + ast::Name::new(&column.column.name), + )), + ast::Operator::Greater, + Box::new(Expr::Literal(gt_value.into())), + ) + }), + Box::new(|rng| { + let lt_value = LTValue::arbitrary_from(rng, column_value).0; + Expr::Binary( + Box::new(ast::Expr::Qualified( + ast::Name::new(table_name), + ast::Name::new(&column.column.name), + )), + ast::Operator::Less, + Box::new(Expr::Literal(lt_value.into())), + ) + }), + ], + rng, + ); + SimplePredicate(Predicate(expr)) + } +} + +impl CompoundPredicate { + /// Decide if you want to create an AND or an OR + /// + /// Creates a Compound Predicate that is TRUE or FALSE for at least a single row + pub fn from_table_binary( + rng: &mut R, + table: &T, + predicate_value: bool, + ) -> Self { + // Cannot pick a row if the table is empty + let rows = table.rows(); + if rows.is_empty() { + return Self(if predicate_value { + Predicate::true_() + } else { + Predicate::false_() + }); + } + let row = pick(rows, rng); + + let predicate = if rng.random_bool(0.7) { + // An AND for true requires each of its children to be true + // An AND for false requires at least one of its children to be false + if predicate_value { + 
(0..rng.random_range(1..=3)) + .map(|_| SimplePredicate::arbitrary_from(rng, (table, row, true)).0) + .reduce(|accum, curr| { + Predicate(Expr::Binary( + Box::new(accum.0), + ast::Operator::And, + Box::new(curr.0), + )) + }) + .unwrap_or(Predicate::true_()) + } else { + // Create a vector of random booleans + let mut booleans = (0..rng.random_range(1..=3)) + .map(|_| rng.random_bool(0.5)) + .collect::>(); + + let len = booleans.len(); + + // Make sure at least one of them is false + if booleans.iter().all(|b| *b) { + booleans[rng.random_range(0..len)] = false; + } + + booleans + .iter() + .map(|b| SimplePredicate::arbitrary_from(rng, (table, row, *b)).0) + .reduce(|accum, curr| { + Predicate(Expr::Binary( + Box::new(accum.0), + ast::Operator::And, + Box::new(curr.0), + )) + }) + .unwrap_or(Predicate::false_()) + } + } else { + // An OR for true requires at least one of its children to be true + // An OR for false requires each of its children to be false + if predicate_value { + // Create a vector of random booleans + let mut booleans = (0..rng.random_range(1..=3)) + .map(|_| rng.random_bool(0.5)) + .collect::>(); + let len = booleans.len(); + // Make sure at least one of them is true + if booleans.iter().all(|b| !*b) { + booleans[rng.random_range(0..len)] = true; + } + + booleans + .iter() + .map(|b| SimplePredicate::arbitrary_from(rng, (table, row, *b)).0) + .reduce(|accum, curr| { + Predicate(Expr::Binary( + Box::new(accum.0), + ast::Operator::Or, + Box::new(curr.0), + )) + }) + .unwrap_or(Predicate::true_()) + } else { + (0..rng.random_range(1..=3)) + .map(|_| SimplePredicate::arbitrary_from(rng, (table, row, false)).0) + .reduce(|accum, curr| { + Predicate(Expr::Binary( + Box::new(accum.0), + ast::Operator::Or, + Box::new(curr.0), + )) + }) + .unwrap_or(Predicate::false_()) + } + }; + Self(predicate) + } +} + +#[cfg(test)] +mod tests { + use rand::{Rng as _, SeedableRng as _}; + use rand_chacha::ChaCha8Rng; + + use crate::{ + generation::{pick, 
predicate::SimplePredicate, Arbitrary, ArbitraryFrom as _}, + model::{ + query::predicate::{expr_to_value, Predicate}, + table::{SimValue, Table}, + }, + }; + + fn get_seed() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + } + + #[test] + fn fuzz_true_binary_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + let row = pick(&values, &mut rng); + let predicate = Predicate::true_binary(&mut rng, &table, row); + let value = expr_to_value(&predicate.0, row, &table); + assert!( + value.as_ref().is_some_and(|value| value.as_bool()), + "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" + ) + } + } + + #[test] + fn fuzz_false_binary_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + let row = pick(&values, &mut rng); + let predicate = Predicate::false_binary(&mut rng, &table, row); + let value = expr_to_value(&predicate.0, row, &table); + assert!( + !value.as_ref().is_some_and(|value| value.as_bool()), + "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" + ) + } + } + + #[test] + fn fuzz_true_binary_simple_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let mut table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + 
.iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + table.rows.extend(values.clone()); + let row = pick(&table.rows, &mut rng); + let predicate = SimplePredicate::true_binary(&mut rng, &table, row); + let result = values + .iter() + .map(|row| predicate.0.test(row, &table)) + .reduce(|accum, curr| accum || curr) + .unwrap_or(false); + assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") + } + } + + #[test] + fn fuzz_false_binary_simple_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let mut table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + table.rows.extend(values.clone()); + let row = pick(&table.rows, &mut rng); + let predicate = SimplePredicate::false_binary(&mut rng, &table, row); + let result = values + .iter() + .map(|row| predicate.0.test(row, &table)) + .any(|res| !res); + assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") + } + } +} diff --git a/sql_generation/generation/predicate/mod.rs b/sql_generation/generation/predicate/mod.rs new file mode 100644 index 000000000..0a06dead0 --- /dev/null +++ b/sql_generation/generation/predicate/mod.rs @@ -0,0 +1,378 @@ +use rand::{seq::SliceRandom as _, Rng}; +use turso_parser::ast::{self, Expr}; + +use crate::model::{ + query::predicate::Predicate, + table::{SimValue, Table, TableContext}, +}; + +use super::{one_of, ArbitraryFrom}; + +mod binary; +mod unary; + +#[derive(Debug)] +struct CompoundPredicate(Predicate); + +#[derive(Debug)] +struct SimplePredicate(Predicate); + +impl, T: TableContext> ArbitraryFrom<(&T, A, bool)> for SimplePredicate { + fn arbitrary_from(rng: &mut R, (table, row, predicate_value): (&T, A, bool)) -> Self { + let row = row.as_ref(); + // Pick an operator + let 
choice = rng.random_range(0..2); + // Pick an operator + match predicate_value { + true => match choice { + 0 => SimplePredicate::true_binary(rng, table, row), + 1 => SimplePredicate::true_unary(rng, table, row), + _ => unreachable!(), + }, + false => match choice { + 0 => SimplePredicate::false_binary(rng, table, row), + 1 => SimplePredicate::false_unary(rng, table, row), + _ => unreachable!(), + }, + } + } +} + +impl ArbitraryFrom<(&T, bool)> for CompoundPredicate { + fn arbitrary_from(rng: &mut R, (table, predicate_value): (&T, bool)) -> Self { + CompoundPredicate::from_table_binary(rng, table, predicate_value) + } +} + +impl ArbitraryFrom<&T> for Predicate { + fn arbitrary_from(rng: &mut R, table: &T) -> Self { + let predicate_value = rng.random_bool(0.5); + Predicate::arbitrary_from(rng, (table, predicate_value)).parens() + } +} + +impl ArbitraryFrom<(&T, bool)> for Predicate { + fn arbitrary_from(rng: &mut R, (table, predicate_value): (&T, bool)) -> Self { + CompoundPredicate::arbitrary_from(rng, (table, predicate_value)).0 + } +} + +impl ArbitraryFrom<(&str, &SimValue)> for Predicate { + fn arbitrary_from(rng: &mut R, (column_name, value): (&str, &SimValue)) -> Self { + Predicate::from_column_binary(rng, column_name, value) + } +} + +impl ArbitraryFrom<(&Table, &Vec)> for Predicate { + fn arbitrary_from(rng: &mut R, (t, row): (&Table, &Vec)) -> Self { + // We want to produce a predicate that is true for the row + // We can do this by creating several predicates that + // are true, some that are false, combined them in ways that correspond to the creation of a true predicate + + // Produce some true and false predicates + let mut true_predicates = (1..=rng.random_range(1..=4)) + .map(|_| Predicate::true_binary(rng, t, row)) + .collect::>(); + + let false_predicates = (0..=rng.random_range(0..=3)) + .map(|_| Predicate::false_binary(rng, t, row)) + .collect::>(); + + // Start building a top level predicate from a true predicate + let mut result = 
true_predicates.pop().unwrap(); + + let mut predicates = true_predicates + .iter() + .map(|p| (true, p.clone())) + .chain(false_predicates.iter().map(|p| (false, p.clone()))) + .collect::>(); + + predicates.shuffle(rng); + + while !predicates.is_empty() { + // Create a new predicate from at least 1 and at most 3 predicates + let context = + predicates[0..rng.random_range(0..=usize::min(3, predicates.len()))].to_vec(); + // Shift `predicates` to remove the predicates in the context + predicates = predicates[context.len()..].to_vec(); + + // `result` is true, so we have the following three options to make a true predicate: + // T or F + // T or T + // T and T + + result = one_of( + vec![ + // T or (X1 or X2 or ... or Xn) + Box::new(|_| { + Predicate(Expr::Binary( + Box::new(result.0.clone()), + ast::Operator::Or, + Box::new( + context + .iter() + .map(|(_, p)| p.clone()) + .reduce(|accum, curr| { + Predicate(Expr::Binary( + Box::new(accum.0), + ast::Operator::Or, + Box::new(curr.0), + )) + }) + .unwrap_or(Predicate::false_()) + .0, + ), + )) + }), + // T or (T1 and T2 and ... and Tn) + Box::new(|_| { + Predicate(Expr::Binary( + Box::new(result.0.clone()), + ast::Operator::Or, + Box::new( + context + .iter() + .map(|(_, p)| p.clone()) + .reduce(|accum, curr| { + Predicate(Expr::Binary( + Box::new(accum.0), + ast::Operator::And, + Box::new(curr.0), + )) + }) + .unwrap_or(Predicate::true_()) + .0, + ), + )) + }), + // T and T + Box::new(|_| { + // Check if all the predicates in the context are true + if context.iter().all(|(b, _)| *b) { + // T and (X1 or X2 or ... 
or Xn) + Predicate(Expr::Binary( + Box::new(result.0.clone()), + ast::Operator::And, + Box::new( + context + .iter() + .map(|(_, p)| p.clone()) + .reduce(|accum, curr| { + Predicate(Expr::Binary( + Box::new(accum.0), + ast::Operator::And, + Box::new(curr.0), + )) + }) + .unwrap_or(Predicate::true_()) + .0, + ), + )) + } + // Check if there is at least one true predicate + else if context.iter().any(|(b, _)| *b) { + // T and (X1 or X2 or ... or Xn) + Predicate(Expr::Binary( + Box::new(result.0.clone()), + ast::Operator::And, + Box::new( + context + .iter() + .map(|(_, p)| p.clone()) + .reduce(|accum, curr| { + Predicate(Expr::Binary( + Box::new(accum.0), + ast::Operator::Or, + Box::new(curr.0), + )) + }) + .unwrap_or(Predicate::false_()) + .0, + ), + )) + // Predicate::And(vec![ + // result.clone(), + // Predicate::Or(context.iter().map(|(_, p)| p.clone()).collect()), + // ]) + } else { + // T and (X1 or X2 or ... or Xn or TRUE) + Predicate(Expr::Binary( + Box::new(result.0.clone()), + ast::Operator::And, + Box::new( + context + .iter() + .map(|(_, p)| p.clone()) + .chain(std::iter::once(Predicate::true_())) + .reduce(|accum, curr| { + Predicate(Expr::Binary( + Box::new(accum.0), + ast::Operator::Or, + Box::new(curr.0), + )) + }) + .unwrap() // Chain guarantees at least one value + .0, + ), + )) + } + }), + ], + rng, + ); + } + result + } +} + +#[cfg(test)] +mod tests { + use rand::{Rng as _, SeedableRng as _}; + use rand_chacha::ChaCha8Rng; + + use crate::{ + generation::{pick, predicate::SimplePredicate, Arbitrary, ArbitraryFrom as _}, + model::{ + query::predicate::{expr_to_value, Predicate}, + table::{SimValue, Table}, + }, + }; + + fn get_seed() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + } + + #[test] + fn fuzz_arbitrary_table_true_simple_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let table = Table::arbitrary(&mut rng); + let 
num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + let row = pick(&values, &mut rng); + let predicate = SimplePredicate::arbitrary_from(&mut rng, (&table, row, true)).0; + let value = expr_to_value(&predicate.0, row, &table); + assert!( + value.as_ref().is_some_and(|value| value.as_bool()), + "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" + ) + } + } + + #[test] + fn fuzz_arbitrary_table_false_simple_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + let row = pick(&values, &mut rng); + let predicate = SimplePredicate::arbitrary_from(&mut rng, (&table, row, false)).0; + let value = expr_to_value(&predicate.0, row, &table); + assert!( + !value.as_ref().is_some_and(|value| value.as_bool()), + "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" + ) + } + } + + #[test] + fn fuzz_arbitrary_row_table_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + let row = pick(&values, &mut rng); + let predicate = Predicate::arbitrary_from(&mut rng, (&table, row)); + let value = expr_to_value(&predicate.0, row, &table); + assert!( + value.as_ref().is_some_and(|value| value.as_bool()), + "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" + ) + } + } + + #[test] + fn 
fuzz_arbitrary_true_table_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let mut table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + table.rows.extend(values.clone()); + let predicate = Predicate::arbitrary_from(&mut rng, (&table, true)); + let result = values + .iter() + .map(|row| predicate.test(row, &table)) + .reduce(|accum, curr| accum || curr) + .unwrap_or(false); + assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") + } + } + + #[test] + fn fuzz_arbitrary_false_table_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let mut table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + table.rows.extend(values.clone()); + let predicate = Predicate::arbitrary_from(&mut rng, (&table, false)); + let result = values + .iter() + .map(|row| predicate.test(row, &table)) + .any(|res| !res); + assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") + } + } +} diff --git a/sql_generation/generation/predicate/unary.rs b/sql_generation/generation/predicate/unary.rs new file mode 100644 index 000000000..6800740d7 --- /dev/null +++ b/sql_generation/generation/predicate/unary.rs @@ -0,0 +1,306 @@ +//! Contains code regarding generation for [ast::Expr::Unary] Predicate +//! TODO: for now just generating [ast::Literal], but want to also generate Columns and any +//! 
arbitrary [ast::Expr] + +use turso_parser::ast::{self, Expr}; + +use crate::{ + generation::{backtrack, pick, predicate::SimplePredicate, ArbitraryFromMaybe}, + model::{ + query::predicate::Predicate, + table::{SimValue, TableContext}, + }, +}; + +pub struct TrueValue(pub SimValue); + +impl ArbitraryFromMaybe<&SimValue> for TrueValue { + fn arbitrary_from_maybe(_rng: &mut R, value: &SimValue) -> Option + where + Self: Sized, + { + // If the Value is a true value return it else you cannot return a true Value + value.as_bool().then_some(Self(value.clone())) + } +} + +impl ArbitraryFromMaybe<&Vec<&SimValue>> for TrueValue { + fn arbitrary_from_maybe(rng: &mut R, values: &Vec<&SimValue>) -> Option + where + Self: Sized, + { + if values.is_empty() { + return Some(Self(SimValue::TRUE)); + } + + let value = pick(values, rng); + Self::arbitrary_from_maybe(rng, *value) + } +} + +pub struct FalseValue(pub SimValue); + +impl ArbitraryFromMaybe<&SimValue> for FalseValue { + fn arbitrary_from_maybe(_rng: &mut R, value: &SimValue) -> Option + where + Self: Sized, + { + // If the Value is a false value return it else you cannot return a false Value + (!value.as_bool()).then_some(Self(value.clone())) + } +} + +impl ArbitraryFromMaybe<&Vec<&SimValue>> for FalseValue { + fn arbitrary_from_maybe(rng: &mut R, values: &Vec<&SimValue>) -> Option + where + Self: Sized, + { + if values.is_empty() { + return Some(Self(SimValue::FALSE)); + } + + let value = pick(values, rng); + Self::arbitrary_from_maybe(rng, *value) + } +} + +pub struct BitNotValue(pub SimValue); + +impl ArbitraryFromMaybe<(&SimValue, bool)> for BitNotValue { + fn arbitrary_from_maybe( + _rng: &mut R, + (value, predicate): (&SimValue, bool), + ) -> Option + where + Self: Sized, + { + let bit_not_val = value.unary_exec(ast::UnaryOperator::BitwiseNot); + // If you bit not the Value and it meets the predicate return Some, else None + (bit_not_val.as_bool() == predicate).then_some(BitNotValue(value.clone())) + } +} + +impl 
ArbitraryFromMaybe<(&Vec<&SimValue>, bool)> for BitNotValue { + fn arbitrary_from_maybe( + rng: &mut R, + (values, predicate): (&Vec<&SimValue>, bool), + ) -> Option + where + Self: Sized, + { + if values.is_empty() { + return None; + } + + let value = pick(values, rng); + Self::arbitrary_from_maybe(rng, (*value, predicate)) + } +} + +// TODO: have some more complex generation with columns names here as well +impl SimplePredicate { + /// Generates a true [ast::Expr::Unary] [SimplePredicate] from a [TableContext] for some values in the table + pub fn true_unary( + rng: &mut R, + table: &T, + row: &[SimValue], + ) -> Self { + let columns = table.columns().collect::>(); + // Pick a random column + let column_index = rng.random_range(0..columns.len()); + let column_value = &row[column_index]; + let num_retries = row.len(); + // Avoid creation of NULLs + if row.is_empty() { + return SimplePredicate(Predicate(Expr::Literal(SimValue::TRUE.into()))); + } + let expr = backtrack( + vec![ + ( + num_retries, + Box::new(|rng| { + TrueValue::arbitrary_from_maybe(rng, column_value).map(|value| { + assert!(value.0.as_bool()); + // Positive is a no-op in Sqlite + Expr::unary(ast::UnaryOperator::Positive, Expr::Literal(value.0.into())) + }) + }), + ), + // ( + // num_retries, + // Box::new(|rng| { + // TrueValue::arbitrary_from_maybe(rng, column_value).map(|value| { + // assert!(value.0.as_bool()); + // // True Value with negative is still True + // Expr::unary(ast::UnaryOperator::Negative, Expr::Literal(value.0.into())) + // }) + // }), + // ), + // ( + // num_retries, + // Box::new(|rng| { + // BitNotValue::arbitrary_from_maybe(rng, (column_value, true)).map(|value| { + // Expr::unary( + // ast::UnaryOperator::BitwiseNot, + // Expr::Literal(value.0.into()), + // ) + // }) + // }), + // ), + ( + num_retries, + Box::new(|rng| { + FalseValue::arbitrary_from_maybe(rng, column_value).map(|value| { + assert!(!value.0.as_bool()); + Expr::unary(ast::UnaryOperator::Not, 
Expr::Literal(value.0.into())) + }) + }), + ), + ], + rng, + ); + // If cannot generate a value + SimplePredicate(Predicate( + expr.unwrap_or(Expr::Literal(SimValue::TRUE.into())), + )) + } + + /// Generates a false [ast::Expr::Unary] [SimplePredicate] from a [TableContext] for a row in the table + pub fn false_unary( + rng: &mut R, + table: &T, + row: &[SimValue], + ) -> Self { + let columns = table.columns().collect::>(); + // Pick a random column + let column_index = rng.random_range(0..columns.len()); + let column_value = &row[column_index]; + let num_retries = row.len(); + // Avoid creation of NULLs + if row.is_empty() { + return SimplePredicate(Predicate(Expr::Literal(SimValue::FALSE.into()))); + } + let expr = backtrack( + vec![ + // ( + // num_retries, + // Box::new(|rng| { + // FalseValue::arbitrary_from_maybe(rng, column_value).map(|value| { + // assert!(!value.0.as_bool()); + // // Positive is a no-op in Sqlite + // Expr::unary(ast::UnaryOperator::Positive, Expr::Literal(value.0.into())) + // }) + // }), + // ), + // ( + // num_retries, + // Box::new(|rng| { + // FalseValue::arbitrary_from_maybe(rng, column_value).map(|value| { + // assert!(!value.0.as_bool()); + // // True Value with negative is still True + // Expr::unary(ast::UnaryOperator::Negative, Expr::Literal(value.0.into())) + // }) + // }), + // ), + // ( + // num_retries, + // Box::new(|rng| { + // BitNotValue::arbitrary_from_maybe(rng, (column_value, false)).map(|value| { + // Expr::unary( + // ast::UnaryOperator::BitwiseNot, + // Expr::Literal(value.0.into()), + // ) + // }) + // }), + // ), + ( + num_retries, + Box::new(|rng| { + TrueValue::arbitrary_from_maybe(rng, column_value).map(|value| { + assert!(value.0.as_bool()); + Expr::unary(ast::UnaryOperator::Not, Expr::Literal(value.0.into())) + }) + }), + ), + ], + rng, + ); + // If cannot generate a value + SimplePredicate(Predicate( + expr.unwrap_or(Expr::Literal(SimValue::FALSE.into())), + )) + } +} + +#[cfg(test)] +mod tests { + use 
rand::{Rng as _, SeedableRng as _}; + use rand_chacha::ChaCha8Rng; + + use crate::{ + generation::{pick, predicate::SimplePredicate, Arbitrary, ArbitraryFrom as _}, + model::table::{SimValue, Table}, + }; + + fn get_seed() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + } + + #[test] + fn fuzz_true_unary_simple_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let mut table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + table.rows.extend(values.clone()); + let row = pick(&table.rows, &mut rng); + let predicate = SimplePredicate::true_unary(&mut rng, &table, row); + let result = values + .iter() + .map(|row| predicate.0.test(row, &table)) + .reduce(|accum, curr| accum || curr) + .unwrap_or(false); + assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") + } + } + + #[test] + fn fuzz_false_unary_simple_predicate() { + let seed = get_seed(); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + for _ in 0..10000 { + let mut table = Table::arbitrary(&mut rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) + .collect() + }) + .collect(); + table.rows.extend(values.clone()); + let row = pick(&table.rows, &mut rng); + let predicate = SimplePredicate::false_unary(&mut rng, &table, row); + let result = values + .iter() + .map(|row| predicate.0.test(row, &table)) + .any(|res| !res); + assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") + } + } +} diff --git a/sql_generation/generation/property.rs b/sql_generation/generation/property.rs new file mode 100644 index 000000000..495f75ec7 --- /dev/null +++ 
b/sql_generation/generation/property.rs @@ -0,0 +1,1533 @@ +use serde::{Deserialize, Serialize}; +use turso_core::{types, LimboError}; +use turso_parser::ast::{self}; + +use crate::{ + model::{ + query::{ + predicate::Predicate, + select::{ + CompoundOperator, CompoundSelect, Distinctness, ResultColumn, SelectBody, + SelectInner, + }, + transaction::{Begin, Commit, Rollback}, + update::Update, + Create, Delete, Drop, Insert, Query, Select, + }, + table::SimValue, + }, + runner::env::SimulatorEnv, +}; + +use super::{ + frequency, pick, pick_index, + plan::{Assertion, Interaction, InteractionStats, ResultSet}, + ArbitraryFrom, +}; + +/// Properties are representations of executable specifications +/// about the database behavior. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) enum Property { + /// Insert-Select is a property in which the inserted row + /// must be in the resulting rows of a select query that has a + /// where clause that matches the inserted row. + /// The execution of the property is as follows + /// INSERT INTO VALUES (...) + /// I_0 + /// I_1 + /// ... + /// I_n + /// SELECT * FROM WHERE + /// The interactions in the middle has the following constraints; + /// - There will be no errors in the middle interactions. + /// - The inserted row will not be deleted. + /// - The inserted row will not be updated. + /// - The table `t` will not be renamed, dropped, or altered. + InsertValuesSelect { + /// The insert query + insert: Insert, + /// Selected row index + row_index: usize, + /// Additional interactions in the middle of the property + queries: Vec, + /// The select query + select: Select, + /// Interactive query information if any + interactive: Option, + }, + /// ReadYourUpdatesBack is a property in which the updated rows + /// must be in the resulting rows of a select query that has a + /// where clause that matches the updated row. 
+ /// The execution of the property is as follows + /// UPDATE SET WHERE + /// SELECT FROM WHERE + /// These interactions are executed in immediate succession + /// just to verify the property that our updates did what they + /// were supposed to do. + ReadYourUpdatesBack { + update: Update, + select: Select, + }, + /// TableHasExpectedContent is a property in which the table + /// must have the expected content, i.e. all the insertions and + /// updates and deletions should have been persisted in the way + /// we think they were. + /// The execution of the property is as follows + /// SELECT * FROM + /// ASSERT + TableHasExpectedContent { + table: String, + }, + /// Double Create Failure is a property in which creating + /// the same table twice leads to an error. + /// The execution of the property is as follows + /// CREATE TABLE (...) + /// I_0 + /// I_1 + /// ... + /// I_n + /// CREATE TABLE (...) -> Error + /// The interactions in the middle has the following constraints; + /// - There will be no errors in the middle interactions. + /// - Table `t` will not be renamed or dropped. + DoubleCreateFailure { + /// The create query + create: Create, + /// Additional interactions in the middle of the property + queries: Vec, + }, + /// Select Limit is a property in which the select query + /// has a limit clause that is respected by the query. + /// The execution of the property is as follows + /// SELECT * FROM WHERE LIMIT + /// This property is a single-interaction property. + /// The interaction has the following constraints; + /// - The select query will respect the limit clause. + SelectLimit { + /// The select query + select: Select, + }, + /// Delete-Select is a property in which the deleted row + /// must not be in the resulting rows of a select query that has a + /// where clause that matches the deleted row. In practice, `p1` of + /// the delete query will be used as the predicate for the select query, + /// hence the select should return NO ROWS. 
+ /// The execution of the property is as follows + /// DELETE FROM WHERE + /// I_0 + /// I_1 + /// ... + /// I_n + /// SELECT * FROM WHERE + /// The interactions in the middle has the following constraints; + /// - There will be no errors in the middle interactions. + /// - A row that holds for the predicate will not be inserted. + /// - The table `t` will not be renamed, dropped, or altered. + DeleteSelect { + table: String, + predicate: Predicate, + queries: Vec, + }, + /// Drop-Select is a property in which selecting from a dropped table + /// should result in an error. + /// The execution of the property is as follows + /// DROP TABLE + /// I_0 + /// I_1 + /// ... + /// I_n + /// SELECT * FROM WHERE -> Error + /// The interactions in the middle has the following constraints; + /// - There will be no errors in the middle interactions. + /// - The table `t` will not be created, no table will be renamed to `t`. + DropSelect { + table: String, + queries: Vec, + select: Select, + }, + /// Select-Select-Optimizer is a property in which we test the optimizer by + /// running two equivalent select queries, one with `SELECT from ` + /// and the other with `SELECT * from WHERE `. As highlighted by + /// Rigger et al. in Non-Optimizing Reference Engine Construction(NoREC), SQLite + /// tends to optimize `where` statements while keeping the result column expressions + /// unoptimized. This property is used to test the optimizer. The property is successful + /// if the two queries return the same number of rows. + SelectSelectOptimizer { + table: String, + predicate: Predicate, + }, + /// Where-True-False-Null is a property that tests the boolean logic implementation + /// in the database. It relies on the fact that `P == true || P == false || P == null` should return true, + /// as SQLite uses a ternary logic system. This property is invented in "Finding Bugs in Database Systems via Query Partitioning" + /// by Rigger et al. 
and it is canonically called Ternary Logic Partitioning (TLP). + WhereTrueFalseNull { + select: Select, + predicate: Predicate, + }, + /// UNION-ALL-Preserves-Cardinality is a property that tests the UNION ALL operator + /// implementation in the database. It relies on the fact that `SELECT * FROM WHERE UNION ALL SELECT * FROM WHERE ` + /// should return the same number of rows as `SELECT FROM WHERE `. + /// > The property is succesfull when the UNION ALL of 2 select queries returns the same number of rows + /// > as the sum of the two select queries. + UNIONAllPreservesCardinality { + select: Select, + where_clause: Predicate, + }, + /// FsyncNoWait is a property which tests if we do not loose any data after not waiting for fsync. + /// + /// # Interactions + /// - Executes the `query` without waiting for fsync + /// - Drop all connections and Reopen the database + /// - Execute the `query` again + /// - Query tables to assert that the values were inserted + /// + FsyncNoWait { + query: Query, + tables: Vec, + }, + FaultyQuery { + query: Query, + tables: Vec, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InteractiveQueryInfo { + start_with_immediate: bool, + end_with_commit: bool, +} + +impl Property { + pub(crate) fn name(&self) -> &str { + match self { + Property::InsertValuesSelect { .. } => "Insert-Values-Select", + Property::ReadYourUpdatesBack { .. } => "Read-Your-Updates-Back", + Property::TableHasExpectedContent { .. } => "Table-Has-Expected-Content", + Property::DoubleCreateFailure { .. } => "Double-Create-Failure", + Property::SelectLimit { .. } => "Select-Limit", + Property::DeleteSelect { .. } => "Delete-Select", + Property::DropSelect { .. } => "Drop-Select", + Property::SelectSelectOptimizer { .. } => "Select-Select-Optimizer", + Property::WhereTrueFalseNull { .. } => "Where-True-False-Null", + Property::FsyncNoWait { .. } => "FsyncNoWait", + Property::FaultyQuery { .. 
} => "FaultyQuery", + Property::UNIONAllPreservesCardinality { .. } => "UNION-All-Preserves-Cardinality", + } + } + /// interactions construct a list of interactions, which is an executable representation of the property. + /// the requirement of property -> vec conversion emerges from the need to serialize the property, + /// and `interaction` cannot be serialized directly. + pub(crate) fn interactions(&self) -> Vec { + match self { + Property::TableHasExpectedContent { table } => { + let table = table.to_string(); + let table_name = table.clone(); + let assumption = Interaction::Assumption(Assertion { + name: format!("table {} exists", table.clone()), + func: Box::new(move |_: &Vec, env: &mut SimulatorEnv| { + if env.tables.iter().any(|t| t.name == table_name) { + Ok(Ok(())) + } else { + Ok(Err(format!("table {table_name} does not exist"))) + } + }), + }); + + let select_interaction = Interaction::Query(Query::Select(Select::simple( + table.clone(), + Predicate::true_(), + ))); + + let assertion = Interaction::Assertion(Assertion { + name: format!("table {} should have the expected content", table.clone()), + func: Box::new(move |stack: &Vec, env| { + let rows = stack.last().unwrap(); + let Ok(rows) = rows else { + return Ok(Err(format!("expected rows but got error: {rows:?}"))); + }; + let sim_table = env + .tables + .iter() + .find(|t| t.name == table) + .expect("table should be in enviroment"); + if rows.len() != sim_table.rows.len() { + return Ok(Err(format!( + "expected {} rows but got {} for table {}", + sim_table.rows.len(), + rows.len(), + table.clone() + ))); + } + for expected_row in sim_table.rows.iter() { + if !rows.contains(expected_row) { + return Ok(Err(format!( + "expected row {:?} not found in table {}", + expected_row, + table.clone() + ))); + } + } + Ok(Ok(())) + }), + }); + + vec![assumption, select_interaction, assertion] + } + Property::ReadYourUpdatesBack { update, select } => { + let table = update.table().to_string(); + let assumption = 
Interaction::Assumption(Assertion { + name: format!("table {} exists", table.clone()), + func: Box::new(move |_: &Vec, env: &mut SimulatorEnv| { + if env.tables.iter().any(|t| t.name == table.clone()) { + Ok(Ok(())) + } else { + Ok(Err(format!("table {} does not exist", table.clone()))) + } + }), + }); + + let update_interaction = Interaction::Query(Query::Update(update.clone())); + let select_interaction = Interaction::Query(Query::Select(select.clone())); + + let update = update.clone(); + + let table = update.table().to_string(); + + let assertion = Interaction::Assertion(Assertion { + name: format!( + "updated rows should be found and have the updated values for table {}", + table.clone() + ), + func: Box::new(move |stack: &Vec, _| { + let rows = stack.last().unwrap(); + match rows { + Ok(rows) => { + for row in rows { + for (i, (col, val)) in update.set_values.iter().enumerate() { + if &row[i] != val { + return Ok(Err(format!("updated row {} has incorrect value for column {col}: expected {val}, got {}", i, row[i]))); + } + } + } + Ok(Ok(())) + } + Err(err) => Err(LimboError::InternalError(err.to_string())), + } + }), + }); + + vec![ + assumption, + update_interaction, + select_interaction, + assertion, + ] + } + Property::InsertValuesSelect { + insert, + row_index, + queries, + select, + interactive, + } => { + let (table, values) = if let Insert::Values { table, values } = insert { + (table, values) + } else { + unreachable!( + "insert query should be Insert::Values for Insert-Values-Select property" + ) + }; + // Check that the insert query has at least 1 value + assert!( + !values.is_empty(), + "insert query should have at least 1 value" + ); + + // Pick a random row within the insert values + let row = values[*row_index].clone(); + + // Assume that the table exists + let assumption = Interaction::Assumption(Assertion { + name: format!("table {} exists", insert.table()), + func: Box::new({ + let table_name = table.clone(); + move |_: &Vec, env: &mut 
SimulatorEnv| { + if env.tables.iter().any(|t| t.name == table_name) { + Ok(Ok(())) + } else { + Ok(Err(format!("table {table_name} does not exist"))) + } + } + }), + }); + + let assertion = Interaction::Assertion(Assertion { + name: format!( + "row [{:?}] should be found in table {}, interactive={} commit={}, rollback={}", + row.iter().map(|v| v.to_string()).collect::>(), + insert.table(), + interactive.is_some(), + interactive + .as_ref() + .map(|i| i.end_with_commit) + .unwrap_or(false), + interactive + .as_ref() + .map(|i| !i.end_with_commit) + .unwrap_or(false), + ), + func: Box::new(move |stack: &Vec, _| { + let rows = stack.last().unwrap(); + match rows { + Ok(rows) => { + let found = rows.iter().any(|r| r == &row); + if found { + Ok(Ok(())) + } else { + Ok(Err(format!("row [{:?}] not found in table", row.iter().map(|v| v.to_string()).collect::>()))) + } + } + Err(err) => Err(LimboError::InternalError(err.to_string())), + } + }), + }); + + let mut interactions = Vec::new(); + interactions.push(assumption); + interactions.push(Interaction::Query(Query::Insert(insert.clone()))); + interactions.extend(queries.clone().into_iter().map(Interaction::Query)); + interactions.push(Interaction::Query(Query::Select(select.clone()))); + interactions.push(assertion); + + interactions + } + Property::DoubleCreateFailure { create, queries } => { + let table_name = create.table.name.clone(); + + let assumption = Interaction::Assumption(Assertion { + name: "Double-Create-Failure should not be called on an existing table" + .to_string(), + func: Box::new(move |_: &Vec, env: &mut SimulatorEnv| { + if !env.tables.iter().any(|t| t.name == table_name) { + Ok(Ok(())) + } else { + Ok(Err(format!("table {table_name} already exists"))) + } + }), + }); + + let cq1 = Interaction::Query(Query::Create(create.clone())); + let cq2 = Interaction::Query(Query::Create(create.clone())); + + let table_name = create.table.name.clone(); + + let assertion = Interaction::Assertion(Assertion { + 
name: + "creating two tables with the name should result in a failure for the second query" + .to_string(), + func: Box::new(move |stack: &Vec, _| { + let last = stack.last().unwrap(); + match last { + Ok(success) => Ok(Err(format!("expected table creation to fail but it succeeded: {success:?}"))), + Err(e) => { + if e.to_string().to_lowercase().contains(&format!("table {table_name} already exists")) { + Ok(Ok(())) + } else { + Ok(Err(format!("expected table already exists error, got: {e}"))) + } + } + } + }), + }); + + let mut interactions = Vec::new(); + interactions.push(assumption); + interactions.push(cq1); + interactions.extend(queries.clone().into_iter().map(Interaction::Query)); + interactions.push(cq2); + interactions.push(assertion); + + interactions + } + Property::SelectLimit { select } => { + let assumption = Interaction::Assumption(Assertion { + name: format!( + "table ({}) exists", + select + .dependencies() + .into_iter() + .collect::>() + .join(", ") + ), + func: Box::new({ + let table_name = select.dependencies(); + move |_: &Vec, env: &mut SimulatorEnv| { + if table_name + .iter() + .all(|table| env.tables.iter().any(|t| t.name == *table)) + { + Ok(Ok(())) + } else { + let missing_tables = table_name + .iter() + .filter(|t| !env.tables.iter().any(|t2| t2.name == **t)) + .collect::>(); + Ok(Err(format!("missing tables: {missing_tables:?}"))) + } + } + }), + }); + + let limit = select + .limit + .expect("Property::SelectLimit without a LIMIT clause"); + + let assertion = Interaction::Assertion(Assertion { + name: "select query should respect the limit clause".to_string(), + func: Box::new(move |stack: &Vec, _| { + let last = stack.last().unwrap(); + match last { + Ok(rows) => { + if limit >= rows.len() { + Ok(Ok(())) + } else { + Ok(Err(format!( + "limit {} violated: got {} rows", + limit, + rows.len() + ))) + } + } + Err(_) => Ok(Ok(())), + } + }), + }); + + vec![ + assumption, + Interaction::Query(Query::Select(select.clone())), + assertion, + ] 
+ } + Property::DeleteSelect { + table, + predicate, + queries, + } => { + let assumption = Interaction::Assumption(Assertion { + name: format!("table {table} exists"), + func: Box::new({ + let table = table.clone(); + move |_: &Vec, env: &mut SimulatorEnv| { + if env.tables.iter().any(|t| t.name == table) { + Ok(Ok(())) + } else { + { + let available_tables: Vec = + env.tables.iter().map(|t| t.name.clone()).collect(); + Ok(Err(format!( + "table \'{table}\' not found. Available tables: {available_tables:?}" + ))) + } + } + } + }), + }); + + let delete = Interaction::Query(Query::Delete(Delete { + table: table.clone(), + predicate: predicate.clone(), + })); + + let select = Interaction::Query(Query::Select(Select::simple( + table.clone(), + predicate.clone(), + ))); + + let assertion = Interaction::Assertion(Assertion { + name: format!("`{select}` should return no values for table `{table}`",), + func: Box::new(move |stack: &Vec, _| { + let rows = stack.last().unwrap(); + match rows { + Ok(rows) => { + if rows.is_empty() { + Ok(Ok(())) + } else { + Ok(Err(format!( + "expected no rows but got {} rows: {:?}", + rows.len(), + rows.iter() + .map(|r| print_row(r)) + .collect::>() + .join(", ") + ))) + } + } + Err(err) => Err(LimboError::InternalError(err.to_string())), + } + }), + }); + + let mut interactions = Vec::new(); + interactions.push(assumption); + interactions.push(delete); + interactions.extend(queries.clone().into_iter().map(Interaction::Query)); + interactions.push(select); + interactions.push(assertion); + + interactions + } + Property::DropSelect { + table, + queries, + select, + } => { + let assumption = Interaction::Assumption(Assertion { + name: format!("table {table} exists"), + func: Box::new({ + let table = table.clone(); + move |_, env: &mut SimulatorEnv| { + if env.tables.iter().any(|t| t.name == table) { + Ok(Ok(())) + } else { + { + let available_tables: Vec = + env.tables.iter().map(|t| t.name.clone()).collect(); + Ok(Err(format!( + "table 
\'{table}\' not found. Available tables: {available_tables:?}" + ))) + } + } + } + }), + }); + + let table_name = table.clone(); + + let assertion = Interaction::Assertion(Assertion { + name: format!("select query should result in an error for table '{table}'"), + func: Box::new(move |stack: &Vec, _| { + let last = stack.last().unwrap(); + match last { + Ok(success) => Ok(Err(format!( + "expected table creation to fail but it succeeded: {success:?}" + ))), + Err(e) => { + if e.to_string() + .contains(&format!("Table {table_name} does not exist")) + { + Ok(Ok(())) + } else { + Ok(Err(format!( + "expected table does not exist error, got: {e}" + ))) + } + } + } + }), + }); + + let drop = Interaction::Query(Query::Drop(Drop { + table: table.clone(), + })); + + let select = Interaction::Query(Query::Select(select.clone())); + + let mut interactions = Vec::new(); + + interactions.push(assumption); + interactions.push(drop); + interactions.extend(queries.clone().into_iter().map(Interaction::Query)); + interactions.push(select); + interactions.push(assertion); + + interactions + } + Property::SelectSelectOptimizer { table, predicate } => { + let assumption = Interaction::Assumption(Assertion { + name: format!("table {table} exists"), + func: Box::new({ + let table = table.clone(); + move |_: &Vec, env: &mut SimulatorEnv| { + if env.tables.iter().any(|t| t.name == table) { + Ok(Ok(())) + } else { + { + let available_tables: Vec = + env.tables.iter().map(|t| t.name.clone()).collect(); + Ok(Err(format!( + "table \'{table}\' not found. 
Available tables: {available_tables:?}" + ))) + } + } + } + }), + }); + + let select1 = Interaction::Query(Query::Select(Select::single( + table.clone(), + vec![ResultColumn::Expr(predicate.clone())], + Predicate::true_(), + None, + Distinctness::All, + ))); + + let select2_query = Query::Select(Select::simple(table.clone(), predicate.clone())); + + let select2 = Interaction::Query(select2_query); + + let assertion = Interaction::Assertion(Assertion { + name: "select queries should return the same amount of results".to_string(), + func: Box::new(move |stack: &Vec, _| { + let select_star = stack.last().unwrap(); + let select_predicate = stack.get(stack.len() - 2).unwrap(); + match (select_predicate, select_star) { + (Ok(rows1), Ok(rows2)) => { + // If rows1 results have more than 1 column, there is a problem + if rows1.iter().any(|vs| vs.len() > 1) { + return Err(LimboError::InternalError( + "Select query without the star should return only one column".to_string(), + )); + } + // Count the 1s in the select query without the star + let rows1_count = rows1 + .iter() + .filter(|vs| { + let v = vs.first().unwrap(); + v.as_bool() + }) + .count(); + tracing::debug!( + "select1 returned {} rows, select2 returned {} rows", + rows1_count, + rows2.len() + ); + if rows1_count == rows2.len() { + Ok(Ok(())) + } else { + Ok(Err(format!( + "row counts don't match: {} vs {}", + rows1_count, + rows2.len() + ))) + } + } + (Err(e1), Err(e2)) => { + tracing::debug!("Error in select1 AND select2: {}, {}", e1, e2); + Ok(Ok(())) + } + (Err(e), _) | (_, Err(e)) => { + tracing::error!("Error in select1 OR select2: {}", e); + Err(LimboError::InternalError(e.to_string())) + } + } + }), + }); + + vec![assumption, select1, select2, assertion] + } + Property::FsyncNoWait { query, tables } => { + let checks = assert_all_table_values(tables); + Vec::from_iter( + std::iter::once(Interaction::FsyncQuery(query.clone())).chain(checks), + ) + } + Property::FaultyQuery { query, tables } => { + let 
checks = assert_all_table_values(tables); + let query_clone = query.clone(); + let assert = Assertion { + // A fault may not occur as we first signal we want a fault injected, + // then when IO is called the fault triggers. It may happen that a fault is injected + // but no IO happens right after it + name: "fault occured".to_string(), + func: Box::new(move |stack, env: &mut SimulatorEnv| { + let last = stack.last().unwrap(); + match last { + Ok(_) => { + let _ = query_clone.shadow(&mut env.tables); + Ok(Ok(())) + } + Err(err) => { + // We cannot make any assumptions about the error content; all we are about is, if the statement errored, + // we don't shadow the results into the simulator env, i.e. we assume whatever the statement did was rolled back. + tracing::error!("Fault injection produced error: {err}"); + Ok(Ok(())) + } + } + }), + }; + let first = [ + Interaction::FaultyQuery(query.clone()), + Interaction::Assertion(assert), + ] + .into_iter(); + Vec::from_iter(first.chain(checks)) + } + Property::WhereTrueFalseNull { select, predicate } => { + let assumption = Interaction::Assumption(Assertion { + name: format!( + "tables ({}) exists", + select + .dependencies() + .into_iter() + .collect::>() + .join(", ") + ), + func: Box::new({ + let tables = select.dependencies(); + move |_: &Vec, env: &mut SimulatorEnv| { + if tables + .iter() + .all(|table| env.tables.iter().any(|t| t.name == *table)) + { + Ok(Ok(())) + } else { + let missing_tables = tables + .iter() + .filter(|t| !env.tables.iter().any(|t2| t2.name == **t)) + .collect::>(); + Ok(Err(format!("missing tables: {missing_tables:?}"))) + } + } + }), + }); + + let old_predicate = select.body.select.where_clause.clone(); + + let p_true = Predicate::and(vec![old_predicate.clone(), predicate.clone()]); + let p_false = Predicate::and(vec![ + old_predicate.clone(), + Predicate::not(predicate.clone()), + ]); + let p_null = Predicate::and(vec![ + old_predicate.clone(), + Predicate::is(predicate.clone(), 
Predicate::null()), + ]); + + let select_tlp = Select { + body: SelectBody { + select: Box::new(SelectInner { + distinctness: select.body.select.distinctness, + columns: select.body.select.columns.clone(), + from: select.body.select.from.clone(), + where_clause: p_true, + order_by: None, + }), + compounds: vec![ + CompoundSelect { + operator: CompoundOperator::UnionAll, + select: Box::new(SelectInner { + distinctness: select.body.select.distinctness, + columns: select.body.select.columns.clone(), + from: select.body.select.from.clone(), + where_clause: p_false, + order_by: None, + }), + }, + CompoundSelect { + operator: CompoundOperator::UnionAll, + select: Box::new(SelectInner { + distinctness: select.body.select.distinctness, + columns: select.body.select.columns.clone(), + from: select.body.select.from.clone(), + where_clause: p_null, + order_by: None, + }), + }, + ], + }, + limit: None, + }; + + let select = Interaction::Query(Query::Select(select.clone())); + let select_tlp = Interaction::Query(Query::Select(select_tlp)); + + // select and select_tlp should return the same rows + let assertion = Interaction::Assertion(Assertion { + name: "select and select_tlp should return the same rows".to_string(), + func: Box::new(move |stack: &Vec, _: &mut SimulatorEnv| { + if stack.len() < 2 { + return Err(LimboError::InternalError( + "Not enough result sets on the stack".to_string(), + )); + } + + let select_result_set = stack.get(stack.len() - 2).unwrap(); + let select_tlp_result_set = stack.last().unwrap(); + + match (select_result_set, select_tlp_result_set) { + (Ok(select_rows), Ok(select_tlp_rows)) => { + if select_rows.len() != select_tlp_rows.len() { + return Ok(Err(format!("row count mismatch: select returned {} rows, select_tlp returned {} rows", select_rows.len(), select_tlp_rows.len()))); + } + // Check if any row in select_rows is not in select_tlp_rows + for row in select_rows.iter() { + if !select_tlp_rows.iter().any(|r| r == row) { + tracing::debug!( + 
"select and select_tlp returned different rows, ({}) is in select but not in select_tlp", + row.iter().map(|v| v.to_string()).collect::>().join(", ") + ); + return Ok(Err(format!( + "row mismatch: row [{}] exists in select results but not in select_tlp results", + print_row(row) + ))); + } + } + // Check if any row in select_tlp_rows is not in select_rows + for row in select_tlp_rows.iter() { + if !select_rows.iter().any(|r| r == row) { + tracing::debug!( + "select and select_tlp returned different rows, ({}) is in select_tlp but not in select", + row.iter().map(|v| v.to_string()).collect::>().join(", ") + ); + + return Ok(Err(format!( + "row mismatch: row [{}] exists in select_tlp but not in select", + print_row(row) + ))); + } + } + // If we reach here, the rows are the same + tracing::trace!( + "select and select_tlp returned the same rows: {:?}", + select_rows + ); + + Ok(Ok(())) + } + (Err(e), _) | (_, Err(e)) => { + tracing::error!("Error in select or select_tlp: {}", e); + Err(LimboError::InternalError(e.to_string())) + } + } + }), + }); + + vec![assumption, select, select_tlp, assertion] + } + Property::UNIONAllPreservesCardinality { + select, + where_clause, + } => { + let s1 = select.clone(); + let mut s2 = select.clone(); + s2.body.select.where_clause = where_clause.clone(); + let s3 = Select::compound(s1.clone(), s2.clone(), CompoundOperator::UnionAll); + + vec![ + Interaction::Query(Query::Select(s1.clone())), + Interaction::Query(Query::Select(s2.clone())), + Interaction::Query(Query::Select(s3.clone())), + Interaction::Assertion(Assertion { + name: "UNION ALL should preserve cardinality".to_string(), + func: Box::new(move |stack: &Vec, _: &mut SimulatorEnv| { + if stack.len() < 3 { + return Err(LimboError::InternalError( + "Not enough result sets on the stack".to_string(), + )); + } + + let select1 = stack.get(stack.len() - 3).unwrap(); + let select2 = stack.get(stack.len() - 2).unwrap(); + let union_all = stack.last().unwrap(); + + match (select1, 
select2, union_all) { + (Ok(rows1), Ok(rows2), Ok(union_rows)) => { + let count1 = rows1.len(); + let count2 = rows2.len(); + let union_count = union_rows.len(); + if union_count == count1 + count2 { + Ok(Ok(())) + } else { + Ok(Err(format!("UNION ALL should preserve cardinality but it didn't: {count1} + {count2} != {union_count}"))) + } + } + (Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => { + tracing::error!("Error in select queries: {}", e); + Err(LimboError::InternalError(e.to_string())) + } + } + }), + }), + ] + } + } + } +} + +fn assert_all_table_values(tables: &[String]) -> impl Iterator + use<'_> { + let checks = tables.iter().flat_map(|table| { + let select = Interaction::Query(Query::Select(Select::simple( + table.clone(), + Predicate::true_(), + ))); + + let assertion = Interaction::Assertion(Assertion { + name: format!("table {table} should contain all of its expected values"), + func: Box::new({ + let table = table.clone(); + move |stack: &Vec, env: &mut SimulatorEnv| { + let table = env.tables.iter().find(|t| t.name == table).ok_or_else(|| { + LimboError::InternalError(format!( + "table {table} should exist in simulator env" + )) + })?; + let last = stack.last().unwrap(); + match last { + Ok(vals) => { + // Check if all values in the table are present in the result set + // Find a value in the table that is not in the result set + let model_contains_db = table.rows.iter().find(|v| { + !vals.iter().any(|r| { + &r == v + }) + }); + let db_contains_model = vals.iter().find(|v| { + !table.rows.iter().any(|r| &r == v) + }); + + if let Some(model_contains_db) = model_contains_db { + tracing::debug!( + "table {} does not contain the expected values, the simulator model has more rows than the database: {:?}", + table.name, + print_row(model_contains_db) + ); + Ok(Err(format!("table {} does not contain the expected values, the simulator model has more rows than the database: {:?}", table.name, print_row(model_contains_db)))) + } else if let 
Some(db_contains_model) = db_contains_model { + tracing::debug!( + "table {} does not contain the expected values, the database has more rows than the simulator model: {:?}", + table.name, + print_row(db_contains_model) + ); + Ok(Err(format!("table {} does not contain the expected values, the database has more rows than the simulator model: {:?}", table.name, print_row(db_contains_model)))) + } else { + Ok(Ok(())) + } + } + Err(err) => Err(LimboError::InternalError(format!("{err}"))), + } + } + }), + }); + [select, assertion].into_iter() + }); + checks +} + +#[derive(Debug)] +pub(crate) struct Remaining { + pub(crate) read: f64, + pub(crate) write: f64, + pub(crate) create: f64, + pub(crate) create_index: f64, + pub(crate) delete: f64, + pub(crate) update: f64, + pub(crate) drop: f64, +} + +pub(crate) fn remaining(env: &SimulatorEnv, stats: &InteractionStats) -> Remaining { + let remaining_read = ((env.opts.max_interactions as f64 * env.opts.read_percent / 100.0) + - (stats.read_count as f64)) + .max(0.0); + let remaining_write = ((env.opts.max_interactions as f64 * env.opts.write_percent / 100.0) + - (stats.write_count as f64)) + .max(0.0); + let remaining_create = ((env.opts.max_interactions as f64 * env.opts.create_percent / 100.0) + - (stats.create_count as f64)) + .max(0.0); + + let remaining_create_index = + ((env.opts.max_interactions as f64 * env.opts.create_index_percent / 100.0) + - (stats.create_index_count as f64)) + .max(0.0); + + let remaining_delete = ((env.opts.max_interactions as f64 * env.opts.delete_percent / 100.0) + - (stats.delete_count as f64)) + .max(0.0); + let remaining_update = ((env.opts.max_interactions as f64 * env.opts.update_percent / 100.0) + - (stats.update_count as f64)) + .max(0.0); + let remaining_drop = ((env.opts.max_interactions as f64 * env.opts.drop_percent / 100.0) + - (stats.drop_count as f64)) + .max(0.0); + + Remaining { + read: remaining_read, + write: remaining_write, + create: remaining_create, + create_index: 
remaining_create_index, + delete: remaining_delete, + drop: remaining_drop, + update: remaining_update, + } +} + +fn property_insert_values_select( + rng: &mut R, + env: &SimulatorEnv, + remaining: &Remaining, +) -> Property { + // Get a random table + let table = pick(&env.tables, rng); + // Generate rows to insert + let rows = (0..rng.random_range(1..=5)) + .map(|_| Vec::::arbitrary_from(rng, table)) + .collect::>(); + + // Pick a random row to select + let row_index = pick_index(rows.len(), rng); + let row = rows[row_index].clone(); + + // Insert the rows + let insert_query = Insert::Values { + table: table.name.clone(), + values: rows, + }; + + // Choose if we want queries to be executed in an interactive transaction + let interactive = if rng.random_bool(0.5) { + Some(InteractiveQueryInfo { + start_with_immediate: rng.random_bool(0.5), + end_with_commit: rng.random_bool(0.5), + }) + } else { + None + }; + // Create random queries respecting the constraints + let mut queries = Vec::new(); + // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) + // - [x] The inserted row will not be deleted. + // - [x] The inserted row will not be updated. + // - [ ] The table `t` will not be renamed, dropped, or altered. (todo: add this constraint once ALTER or DROP is implemented) + if let Some(ref interactive) = interactive { + queries.push(Query::Begin(Begin { + immediate: interactive.start_with_immediate, + })); + } + for _ in 0..rng.random_range(0..3) { + let query = Query::arbitrary_from(rng, (env, remaining)); + match &query { + Query::Delete(Delete { + table: t, + predicate, + }) => { + // The inserted row will not be deleted. + if t == &table.name && predicate.test(&row, table) { + continue; + } + } + Query::Create(Create { table: t }) => { + // There will be no errors in the middle interactions. 
+ // - Creating the same table is an error + if t.name == table.name { + continue; + } + } + Query::Update(Update { + table: t, + set_values: _, + predicate, + }) => { + // The inserted row will not be updated. + if t == &table.name && predicate.test(&row, table) { + continue; + } + } + _ => (), + } + queries.push(query); + } + if let Some(ref interactive) = interactive { + queries.push(if interactive.end_with_commit { + Query::Commit(Commit) + } else { + Query::Rollback(Rollback) + }); + } + + // Select the row + let select_query = Select::simple( + table.name.clone(), + Predicate::arbitrary_from(rng, (table, &row)), + ); + + Property::InsertValuesSelect { + insert: insert_query, + row_index, + queries, + select: select_query, + interactive, + } +} + +fn property_read_your_updates_back(rng: &mut R, env: &SimulatorEnv) -> Property { + // e.g. UPDATE t SET a=1, b=2 WHERE c=1; + let update = Update::arbitrary_from(rng, env); + // e.g. SELECT a, b FROM t WHERE c=1; + let select = Select::single( + update.table().to_string(), + update + .set_values + .iter() + .map(|(col, _)| ResultColumn::Column(col.clone())) + .collect(), + update.predicate.clone(), + None, + Distinctness::All, + ); + + Property::ReadYourUpdatesBack { update, select } +} + +fn property_table_has_expected_content(rng: &mut R, env: &SimulatorEnv) -> Property { + // Get a random table + let table = pick(&env.tables, rng); + Property::TableHasExpectedContent { + table: table.name.clone(), + } +} + +fn property_select_limit(rng: &mut R, env: &SimulatorEnv) -> Property { + // Get a random table + let table = pick(&env.tables, rng); + // Select the table + let select = Select::single( + table.name.clone(), + vec![ResultColumn::Star], + Predicate::arbitrary_from(rng, table), + Some(rng.random_range(1..=5)), + Distinctness::All, + ); + Property::SelectLimit { select } +} + +fn property_double_create_failure( + rng: &mut R, + env: &SimulatorEnv, + remaining: &Remaining, +) -> Property { + // Get a random table 
+ let table = pick(&env.tables, rng); + // Create the table + let create_query = Create { + table: table.clone(), + }; + + // Create random queries respecting the constraints + let mut queries = Vec::new(); + // The interactions in the middle has the following constraints; + // - [x] There will be no errors in the middle interactions.(best effort) + // - [ ] Table `t` will not be renamed or dropped.(todo: add this constraint once ALTER or DROP is implemented) + for _ in 0..rng.random_range(0..3) { + let query = Query::arbitrary_from(rng, (env, remaining)); + if let Query::Create(Create { table: t }) = &query { + // There will be no errors in the middle interactions. + // - Creating the same table is an error + if t.name == table.name { + continue; + } + } + queries.push(query); + } + + Property::DoubleCreateFailure { + create: create_query, + queries, + } +} + +fn property_delete_select( + rng: &mut R, + env: &SimulatorEnv, + remaining: &Remaining, +) -> Property { + // Get a random table + let table = pick(&env.tables, rng); + // Generate a random predicate + let predicate = Predicate::arbitrary_from(rng, table); + + // Create random queries respecting the constraints + let mut queries = Vec::new(); + // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) + // - [x] A row that holds for the predicate will not be inserted. + // - [ ] The table `t` will not be renamed, dropped, or altered. (todo: add this constraint once ALTER or DROP is implemented) + for _ in 0..rng.random_range(0..3) { + let query = Query::arbitrary_from(rng, (env, remaining)); + match &query { + Query::Insert(Insert::Values { table: t, values }) => { + // A row that holds for the predicate will not be inserted. 
+ if t == &table.name && values.iter().any(|v| predicate.test(v, table)) { + continue; + } + } + Query::Insert(Insert::Select { + table: t, + select: _, + }) => { + // A row that holds for the predicate will not be inserted. + if t == &table.name { + continue; + } + } + Query::Update(Update { table: t, .. }) => { + // A row that holds for the predicate will not be updated. + if t == &table.name { + continue; + } + } + Query::Create(Create { table: t }) => { + // There will be no errors in the middle interactions. + // - Creating the same table is an error + if t.name == table.name { + continue; + } + } + _ => (), + } + queries.push(query); + } + + Property::DeleteSelect { + table: table.name.clone(), + predicate, + queries, + } +} + +fn property_drop_select( + rng: &mut R, + env: &SimulatorEnv, + remaining: &Remaining, +) -> Property { + // Get a random table + let table = pick(&env.tables, rng); + + // Create random queries respecting the constraints + let mut queries = Vec::new(); + // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) + // - [-] The table `t` will not be created, no table will be renamed to `t`. 
(todo: update this constraint once ALTER is implemented) + for _ in 0..rng.random_range(0..3) { + let query = Query::arbitrary_from(rng, (env, remaining)); + if let Query::Create(Create { table: t }) = &query { + // - The table `t` will not be created + if t.name == table.name { + continue; + } + } + queries.push(query); + } + + let select = Select::simple(table.name.clone(), Predicate::arbitrary_from(rng, table)); + + Property::DropSelect { + table: table.name.clone(), + queries, + select, + } +} + +fn property_select_select_optimizer(rng: &mut R, env: &SimulatorEnv) -> Property { + // Get a random table + let table = pick(&env.tables, rng); + // Generate a random predicate + let predicate = Predicate::arbitrary_from(rng, table); + // Transform into a Binary predicate to force values to be casted to a bool + let expr = ast::Expr::Binary( + Box::new(predicate.0), + ast::Operator::And, + Box::new(Predicate::true_().0), + ); + + Property::SelectSelectOptimizer { + table: table.name.clone(), + predicate: Predicate(expr), + } +} + +fn property_where_true_false_null(rng: &mut R, env: &SimulatorEnv) -> Property { + // Get a random table + let table = pick(&env.tables, rng); + // Generate a random predicate + let p1 = Predicate::arbitrary_from(rng, table); + let p2 = Predicate::arbitrary_from(rng, table); + + // Create the select query + let select = Select::simple(table.name.clone(), p1); + + Property::WhereTrueFalseNull { + select, + predicate: p2, + } +} + +fn property_union_all_preserves_cardinality( + rng: &mut R, + env: &SimulatorEnv, +) -> Property { + // Get a random table + let table = pick(&env.tables, rng); + // Generate a random predicate + let p1 = Predicate::arbitrary_from(rng, table); + let p2 = Predicate::arbitrary_from(rng, table); + + // Create the select query + let select = Select::single( + table.name.clone(), + vec![ResultColumn::Star], + p1, + None, + Distinctness::All, + ); + + Property::UNIONAllPreservesCardinality { + select, + where_clause: p2, 
+ } +} + +fn property_fsync_no_wait( + rng: &mut R, + env: &SimulatorEnv, + remaining: &Remaining, +) -> Property { + Property::FsyncNoWait { + query: Query::arbitrary_from(rng, (env, remaining)), + tables: env.tables.iter().map(|t| t.name.clone()).collect(), + } +} + +fn property_faulty_query( + rng: &mut R, + env: &SimulatorEnv, + remaining: &Remaining, +) -> Property { + Property::FaultyQuery { + query: Query::arbitrary_from(rng, (env, remaining)), + tables: env.tables.iter().map(|t| t.name.clone()).collect(), + } +} + +impl ArbitraryFrom<(&SimulatorEnv, &InteractionStats)> for Property { + fn arbitrary_from( + rng: &mut R, + (env, stats): (&SimulatorEnv, &InteractionStats), + ) -> Self { + let remaining_ = remaining(env, stats); + + frequency( + vec![ + ( + if !env.opts.disable_insert_values_select { + f64::min(remaining_.read, remaining_.write) + } else { + 0.0 + }, + Box::new(|rng: &mut R| property_insert_values_select(rng, env, &remaining_)), + ), + ( + remaining_.read, + Box::new(|rng: &mut R| property_table_has_expected_content(rng, env)), + ), + ( + f64::min(remaining_.read, remaining_.write), + Box::new(|rng: &mut R| property_read_your_updates_back(rng, env)), + ), + ( + if !env.opts.disable_double_create_failure { + remaining_.create / 2.0 + } else { + 0.0 + }, + Box::new(|rng: &mut R| property_double_create_failure(rng, env, &remaining_)), + ), + ( + if !env.opts.disable_select_limit { + remaining_.read + } else { + 0.0 + }, + Box::new(|rng: &mut R| property_select_limit(rng, env)), + ), + ( + if !env.opts.disable_delete_select { + f64::min(remaining_.read, remaining_.write).min(remaining_.delete) + } else { + 0.0 + }, + Box::new(|rng: &mut R| property_delete_select(rng, env, &remaining_)), + ), + ( + if !env.opts.disable_drop_select { + // remaining_.drop + 0.0 + } else { + 0.0 + }, + Box::new(|rng: &mut R| property_drop_select(rng, env, &remaining_)), + ), + ( + if !env.opts.disable_select_optimizer { + remaining_.read / 2.0 + } else { + 0.0 + }, + 
Box::new(|rng: &mut R| property_select_select_optimizer(rng, env)), + ), + ( + if env.opts.experimental_indexes && !env.opts.disable_where_true_false_null { + remaining_.read / 2.0 + } else { + 0.0 + }, + Box::new(|rng: &mut R| property_where_true_false_null(rng, env)), + ), + ( + if env.opts.experimental_indexes + && !env.opts.disable_union_all_preserves_cardinality + { + remaining_.read / 3.0 + } else { + 0.0 + }, + Box::new(|rng: &mut R| property_union_all_preserves_cardinality(rng, env)), + ), + ( + if !env.opts.disable_fsync_no_wait { + 50.0 // Freestyle number + } else { + 0.0 + }, + Box::new(|rng: &mut R| property_fsync_no_wait(rng, env, &remaining_)), + ), + ( + if !env.opts.disable_faulty_query { + 20.0 + } else { + 0.0 + }, + Box::new(|rng: &mut R| property_faulty_query(rng, env, &remaining_)), + ), + ], + rng, + ) + } +} + +fn print_row(row: &[SimValue]) -> String { + row.iter() + .map(|v| match &v.0 { + types::Value::Null => "NULL".to_string(), + types::Value::Integer(i) => i.to_string(), + types::Value::Float(f) => f.to_string(), + types::Value::Text(t) => t.to_string(), + types::Value::Blob(b) => format!( + "X'{}'", + b.iter() + .fold(String::new(), |acc, b| acc + &format!("{b:02X}")) + ), + }) + .collect::>() + .join(", ") +} diff --git a/sql_generation/generation/query.rs b/sql_generation/generation/query.rs new file mode 100644 index 000000000..eff24613c --- /dev/null +++ b/sql_generation/generation/query.rs @@ -0,0 +1,447 @@ +use crate::generation::{ + gen_random_text, pick_n_unique, Arbitrary, ArbitraryFrom, ArbitrarySizedFrom, +}; +use crate::model::query::predicate::Predicate; +use crate::model::query::select::{ + CompoundOperator, CompoundSelect, Distinctness, FromClause, OrderBy, ResultColumn, SelectBody, + SelectInner, +}; +use crate::model::query::update::Update; +use crate::model::query::{Create, CreateIndex, Delete, Drop, Insert, Query, Select}; +use crate::model::table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext}; 
+use crate::SimulatorEnv; +use itertools::Itertools; +use rand::Rng; +use turso_parser::ast::{Expr, SortOrder}; + +use super::property::Remaining; +use super::{backtrack, frequency, pick}; + +impl Arbitrary for Create { + fn arbitrary(rng: &mut R) -> Self { + Create { + table: Table::arbitrary(rng), + } + } +} + +impl ArbitraryFrom<&Vec> for FromClause { + fn arbitrary_from(rng: &mut R, tables: &Vec
) -> Self { + let num_joins = match rng.random_range(0..=100) { + 0..=90 => 0, + 91..=97 => 1, + 98..=100 => 2, + _ => unreachable!(), + }; + + let mut tables = tables.clone(); + let mut table = pick(&tables, rng).clone(); + + tables.retain(|t| t.name != table.name); + + let name = table.name.clone(); + + let mut table_context = JoinTable { + tables: Vec::new(), + rows: Vec::new(), + }; + + let joins: Vec<_> = (0..num_joins) + .filter_map(|_| { + if tables.is_empty() { + return None; + } + let join_table = pick(&tables, rng).clone(); + let joined_table_name = join_table.name.clone(); + + tables.retain(|t| t.name != join_table.name); + table_context.rows = table_context + .rows + .iter() + .cartesian_product(join_table.rows.iter()) + .map(|(t_row, j_row)| { + let mut row = t_row.clone(); + row.extend(j_row.clone()); + row + }) + .collect(); + // TODO: inneficient. use a Deque to push_front? + table_context.tables.insert(0, join_table); + for row in &mut table.rows { + assert_eq!( + row.len(), + table.columns.len(), + "Row length does not match column length after join" + ); + } + + let predicate = Predicate::arbitrary_from(rng, &table); + Some(JoinedTable { + table: joined_table_name, + join_type: JoinType::Inner, + on: predicate, + }) + }) + .collect(); + FromClause { table: name, joins } + } +} + +impl ArbitraryFrom<&SimulatorEnv> for SelectInner { + fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { + let from = FromClause::arbitrary_from(rng, &env.tables); + let mut tables = env.tables.clone(); + // todo: this is a temporary hack because env is not separated from the tables + let join_table = from + .shadow(&mut tables) + .expect("Failed to shadow FromClause"); + let cuml_col_count = join_table.columns().count(); + + let order_by = 'order_by: { + if rng.random_bool(0.3) { + let order_by_table_candidates = from + .joins + .iter() + .map(|j| j.table.clone()) + .chain(std::iter::once(from.table.clone())) + .collect::>(); + let order_by_col_count = + 
(rng.random::() * rng.random::() * (cuml_col_count as f64)) as usize; // skew towards 0 + if order_by_col_count == 0 { + break 'order_by None; + } + let mut col_names = std::collections::HashSet::new(); + let mut order_by_cols = Vec::new(); + while order_by_cols.len() < order_by_col_count { + let table = pick(&order_by_table_candidates, rng); + let table = tables.iter().find(|t| t.name == *table).unwrap(); + let col = pick(&table.columns, rng); + let col_name = format!("{}.{}", table.name, col.name); + if col_names.insert(col_name.clone()) { + order_by_cols.push(( + col_name, + if rng.random_bool(0.5) { + SortOrder::Asc + } else { + SortOrder::Desc + }, + )); + } + } + Some(OrderBy { + columns: order_by_cols, + }) + } else { + None + } + }; + + SelectInner { + distinctness: if env.opts.experimental_indexes { + Distinctness::arbitrary(rng) + } else { + Distinctness::All + }, + columns: vec![ResultColumn::Star], + from: Some(from), + where_clause: Predicate::arbitrary_from(rng, &join_table), + order_by, + } + } +} + +impl ArbitrarySizedFrom<&SimulatorEnv> for SelectInner { + fn arbitrary_sized_from( + rng: &mut R, + env: &SimulatorEnv, + num_result_columns: usize, + ) -> Self { + let mut select_inner = SelectInner::arbitrary_from(rng, env); + let select_from = &select_inner.from.as_ref().unwrap(); + let table_names = select_from + .joins + .iter() + .map(|j| j.table.clone()) + .chain(std::iter::once(select_from.table.clone())) + .collect::>(); + + let flat_columns_names = table_names + .iter() + .flat_map(|t| { + env.tables + .iter() + .find(|table| table.name == *t) + .unwrap() + .columns + .iter() + .map(|c| format!("{}.{}", t.clone(), c.name)) + }) + .collect::>(); + let selected_columns = pick_unique(&flat_columns_names, num_result_columns, rng); + let mut columns = Vec::new(); + for column_name in selected_columns { + columns.push(ResultColumn::Column(column_name.clone())); + } + select_inner.columns = columns; + select_inner + } +} + +impl Arbitrary for 
Distinctness { + fn arbitrary(rng: &mut R) -> Self { + match rng.random_range(0..=5) { + 0..4 => Distinctness::All, + _ => Distinctness::Distinct, + } + } +} +impl Arbitrary for CompoundOperator { + fn arbitrary(rng: &mut R) -> Self { + match rng.random_range(0..=1) { + 0 => CompoundOperator::Union, + 1 => CompoundOperator::UnionAll, + _ => unreachable!(), + } + } +} + +/// SelectFree is a wrapper around Select that allows for arbitrary generation +/// of selects without requiring a specific environment, which is useful for generating +/// arbitrary expressions without referring to the tables. +pub(crate) struct SelectFree(pub(crate) Select); + +impl ArbitraryFrom<&SimulatorEnv> for SelectFree { + fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { + let expr = Predicate(Expr::arbitrary_sized_from(rng, env, 8)); + let select = Select::expr(expr); + Self(select) + } +} + +impl ArbitraryFrom<&SimulatorEnv> for Select { + fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { + // Generate a number of selects based on the query size + // If experimental indexes are enabled, we can have selects with compounds + // Otherwise, we just have a single select with no compounds + let num_compound_selects = if env.opts.experimental_indexes { + match rng.random_range(0..=100) { + 0..=95 => 0, + 96..=99 => 1, + 100 => 2, + _ => unreachable!(), + } + } else { + 0 + }; + + let min_column_count_across_tables = + env.tables.iter().map(|t| t.columns.len()).min().unwrap(); + + let num_result_columns = rng.random_range(1..=min_column_count_across_tables); + + let mut first = SelectInner::arbitrary_sized_from(rng, env, num_result_columns); + + let mut rest: Vec = (0..num_compound_selects) + .map(|_| SelectInner::arbitrary_sized_from(rng, env, num_result_columns)) + .collect(); + + if !rest.is_empty() { + // ORDER BY is not supported in compound selects yet + first.order_by = None; + for s in &mut rest { + s.order_by = None; + } + } + + Self { + body: SelectBody { + 
select: Box::new(first), + compounds: rest + .into_iter() + .map(|s| CompoundSelect { + operator: CompoundOperator::arbitrary(rng), + select: Box::new(s), + }) + .collect(), + }, + limit: None, + } + } +} + +impl ArbitraryFrom<&SimulatorEnv> for Insert { + fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { + let gen_values = |rng: &mut R| { + let table = pick(&env.tables, rng); + let num_rows = rng.random_range(1..10); + let values: Vec> = (0..num_rows) + .map(|_| { + table + .columns + .iter() + .map(|c| SimValue::arbitrary_from(rng, &c.column_type)) + .collect() + }) + .collect(); + Some(Insert::Values { + table: table.name.clone(), + values, + }) + }; + + let _gen_select = |rng: &mut R| { + // Find a non-empty table + let select_table = env.tables.iter().find(|t| !t.rows.is_empty())?; + let row = pick(&select_table.rows, rng); + let predicate = Predicate::arbitrary_from(rng, (select_table, row)); + // Pick another table to insert into + let select = Select::simple(select_table.name.clone(), predicate); + let table = pick(&env.tables, rng); + Some(Insert::Select { + table: table.name.clone(), + select: Box::new(select), + }) + }; + + // TODO: Add back gen_select when https://github.com/tursodatabase/turso/issues/2129 is fixed. 
+ // Backtrack here cannot return None + backtrack(vec![(1, Box::new(gen_values))], rng).unwrap() + } +} + +impl ArbitraryFrom<&SimulatorEnv> for Delete { + fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { + let table = pick(&env.tables, rng); + Self { + table: table.name.clone(), + predicate: Predicate::arbitrary_from(rng, table), + } + } +} + +impl ArbitraryFrom<&SimulatorEnv> for Drop { + fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { + let table = pick(&env.tables, rng); + Self { + table: table.name.clone(), + } + } +} + +impl ArbitraryFrom<&SimulatorEnv> for CreateIndex { + fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { + assert!( + !env.tables.is_empty(), + "Cannot create an index when no tables exist in the environment." + ); + + let table = pick(&env.tables, rng); + + if table.columns.is_empty() { + panic!( + "Cannot create an index on table '{}' as it has no columns.", + table.name + ); + } + + let num_columns_to_pick = rng.random_range(1..=table.columns.len()); + let picked_column_indices = pick_n_unique(0..table.columns.len(), num_columns_to_pick, rng); + + let columns = picked_column_indices + .into_iter() + .map(|i| { + let column = &table.columns[i]; + ( + column.name.clone(), + if rng.random_bool(0.5) { + SortOrder::Asc + } else { + SortOrder::Desc + }, + ) + }) + .collect::>(); + + let index_name = format!( + "idx_{}_{}", + table.name, + gen_random_text(rng).chars().take(8).collect::() + ); + + CreateIndex { + index_name, + table_name: table.name.clone(), + columns, + } + } +} + +impl ArbitraryFrom<(&SimulatorEnv, &Remaining)> for Query { + fn arbitrary_from(rng: &mut R, (env, remaining): (&SimulatorEnv, &Remaining)) -> Self { + frequency( + vec![ + ( + remaining.create, + Box::new(|rng| Self::Create(Create::arbitrary(rng))), + ), + ( + remaining.read, + Box::new(|rng| Self::Select(Select::arbitrary_from(rng, env))), + ), + ( + remaining.write, + Box::new(|rng| Self::Insert(Insert::arbitrary_from(rng, env))), 
+ ), + ( + remaining.update, + Box::new(|rng| Self::Update(Update::arbitrary_from(rng, env))), + ), + ( + f64::min(remaining.write, remaining.delete), + Box::new(|rng| Self::Delete(Delete::arbitrary_from(rng, env))), + ), + ], + rng, + ) + } +} + +fn pick_unique( + items: &[T], + count: usize, + rng: &mut impl rand::Rng, +) -> Vec +where + ::Owned: PartialEq, +{ + let mut picked: Vec = Vec::new(); + while picked.len() < count { + let item = pick(items, rng); + if !picked.contains(&item.to_owned()) { + picked.push(item.to_owned()); + } + } + picked +} + +impl ArbitraryFrom<&SimulatorEnv> for Update { + fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { + let table = pick(&env.tables, rng); + let num_cols = rng.random_range(1..=table.columns.len()); + let columns = pick_unique(&table.columns, num_cols, rng); + let set_values: Vec<(String, SimValue)> = columns + .iter() + .map(|column| { + ( + column.name.clone(), + SimValue::arbitrary_from(rng, &column.column_type), + ) + }) + .collect(); + Update { + table: table.name.clone(), + set_values, + predicate: Predicate::arbitrary_from(rng, table), + } + } +} diff --git a/sql_generation/generation/table.rs b/sql_generation/generation/table.rs new file mode 100644 index 000000000..fdddb6ff2 --- /dev/null +++ b/sql_generation/generation/table.rs @@ -0,0 +1,258 @@ +use std::collections::HashSet; + +use rand::Rng; +use turso_core::Value; + +use crate::generation::{gen_random_text, pick, readable_name_custom, Arbitrary, ArbitraryFrom}; +use crate::model::table::{Column, ColumnType, Name, SimValue, Table}; + +use super::ArbitraryFromMaybe; + +impl Arbitrary for Name { + fn arbitrary(rng: &mut R) -> Self { + let name = readable_name_custom("_", rng); + Name(name.replace("-", "_")) + } +} + +impl Arbitrary for Table { + fn arbitrary(rng: &mut R) -> Self { + let name = Name::arbitrary(rng).0; + let columns = loop { + let large_table = rng.random_bool(0.1); + let column_size = if large_table { + rng.random_range(64..125) 
// todo: make this higher (128+) + } else { + rng.random_range(1..=10) + }; + let columns = (1..=column_size) + .map(|_| Column::arbitrary(rng)) + .collect::>(); + // TODO: see if there is a better way to detect duplicates here + let mut set = HashSet::with_capacity(columns.len()); + set.extend(columns.iter()); + // Has repeated column name inside so generate again + if set.len() != columns.len() { + continue; + } + break columns; + }; + + Table { + rows: Vec::new(), + name, + columns, + indexes: vec![], + } + } +} + +impl Arbitrary for Column { + fn arbitrary(rng: &mut R) -> Self { + let name = Name::arbitrary(rng).0; + let column_type = ColumnType::arbitrary(rng); + Self { + name, + column_type, + primary: false, + unique: false, + } + } +} + +impl Arbitrary for ColumnType { + fn arbitrary(rng: &mut R) -> Self { + pick(&[Self::Integer, Self::Float, Self::Text, Self::Blob], rng).to_owned() + } +} + +impl ArbitraryFrom<&Table> for Vec { + fn arbitrary_from(rng: &mut R, table: &Table) -> Self { + let mut row = Vec::new(); + for column in table.columns.iter() { + let value = SimValue::arbitrary_from(rng, &column.column_type); + row.push(value); + } + row + } +} + +impl ArbitraryFrom<&Vec<&SimValue>> for SimValue { + fn arbitrary_from(rng: &mut R, values: &Vec<&Self>) -> Self { + if values.is_empty() { + return Self(Value::Null); + } + + pick(values, rng).to_owned().clone() + } +} + +impl ArbitraryFrom<&ColumnType> for SimValue { + fn arbitrary_from(rng: &mut R, column_type: &ColumnType) -> Self { + let value = match column_type { + ColumnType::Integer => Value::Integer(rng.random_range(i64::MIN..i64::MAX)), + ColumnType::Float => Value::Float(rng.random_range(-1e10..1e10)), + ColumnType::Text => Value::build_text(gen_random_text(rng)), + ColumnType::Blob => Value::Blob(gen_random_text(rng).as_bytes().to_vec()), + }; + SimValue(value) + } +} + +pub(crate) struct LTValue(pub(crate) SimValue); + +impl ArbitraryFrom<&Vec<&SimValue>> for LTValue { + fn arbitrary_from(rng: 
&mut R, values: &Vec<&SimValue>) -> Self { + if values.is_empty() { + return Self(SimValue(Value::Null)); + } + + // Get value less than all values + let value = Value::exec_min(values.iter().map(|value| &value.0)); + Self::arbitrary_from(rng, &SimValue(value)) + } +} + +impl ArbitraryFrom<&SimValue> for LTValue { + fn arbitrary_from(rng: &mut R, value: &SimValue) -> Self { + let new_value = match &value.0 { + Value::Integer(i) => Value::Integer(rng.random_range(i64::MIN..*i - 1)), + Value::Float(f) => Value::Float(f - rng.random_range(0.0..1e10)), + value @ Value::Text(..) => { + // Either shorten the string, or make at least one character smaller and mutate the rest + let mut t = value.to_string(); + if rng.random_bool(0.01) { + t.pop(); + Value::build_text(t) + } else { + let mut t = t.chars().map(|c| c as u32).collect::>(); + let index = rng.random_range(0..t.len()); + t[index] -= 1; + // Mutate the rest of the string + for val in t.iter_mut().skip(index + 1) { + *val = rng.random_range('a' as u32..='z' as u32); + } + let t = t + .into_iter() + .map(|c| char::from_u32(c).unwrap_or('z')) + .collect::(); + Value::build_text(t) + } + } + Value::Blob(b) => { + // Either shorten the blob, or make at least one byte smaller and mutate the rest + let mut b = b.clone(); + if rng.random_bool(0.01) { + b.pop(); + Value::Blob(b) + } else { + let index = rng.random_range(0..b.len()); + b[index] -= 1; + // Mutate the rest of the blob + for val in b.iter_mut().skip(index + 1) { + *val = rng.random_range(0..=255); + } + Value::Blob(b) + } + } + _ => unreachable!(), + }; + Self(SimValue(new_value)) + } +} + +pub(crate) struct GTValue(pub(crate) SimValue); + +impl ArbitraryFrom<&Vec<&SimValue>> for GTValue { + fn arbitrary_from(rng: &mut R, values: &Vec<&SimValue>) -> Self { + if values.is_empty() { + return Self(SimValue(Value::Null)); + } + // Get value greater than all values + let value = Value::exec_max(values.iter().map(|value| &value.0)); + + Self::arbitrary_from(rng, 
&SimValue(value)) + } +} + +impl ArbitraryFrom<&SimValue> for GTValue { + fn arbitrary_from(rng: &mut R, value: &SimValue) -> Self { + let new_value = match &value.0 { + Value::Integer(i) => Value::Integer(rng.random_range(*i..i64::MAX)), + Value::Float(f) => Value::Float(rng.random_range(*f..1e10)), + value @ Value::Text(..) => { + // Either lengthen the string, or make at least one character smaller and mutate the rest + let mut t = value.to_string(); + if rng.random_bool(0.01) { + t.push(rng.random_range(0..=255) as u8 as char); + Value::build_text(t) + } else { + let mut t = t.chars().map(|c| c as u32).collect::>(); + let index = rng.random_range(0..t.len()); + t[index] += 1; + // Mutate the rest of the string + for val in t.iter_mut().skip(index + 1) { + *val = rng.random_range('a' as u32..='z' as u32); + } + let t = t + .into_iter() + .map(|c| char::from_u32(c).unwrap_or('a')) + .collect::(); + Value::build_text(t) + } + } + Value::Blob(b) => { + // Either lengthen the blob, or make at least one byte smaller and mutate the rest + let mut b = b.clone(); + if rng.random_bool(0.01) { + b.push(rng.random_range(0..=255)); + Value::Blob(b) + } else { + let index = rng.random_range(0..b.len()); + b[index] += 1; + // Mutate the rest of the blob + for val in b.iter_mut().skip(index + 1) { + *val = rng.random_range(0..=255); + } + Value::Blob(b) + } + } + _ => unreachable!(), + }; + Self(SimValue(new_value)) + } +} + +pub(crate) struct LikeValue(pub(crate) SimValue); + +impl ArbitraryFromMaybe<&SimValue> for LikeValue { + fn arbitrary_from_maybe(rng: &mut R, value: &SimValue) -> Option { + match &value.0 { + value @ Value::Text(..) 
=> { + let t = value.to_string(); + let mut t = t.chars().collect::>(); + // Remove a number of characters, either insert `_` for each character removed, or + // insert one `%` for the whole substring + let mut i = 0; + while i < t.len() { + if rng.random_bool(0.1) { + t[i] = '_'; + } else if rng.random_bool(0.05) { + t[i] = '%'; + // skip a list of characters + for _ in 0..rng.random_range(0..=3.min(t.len() - i - 1)) { + t.remove(i + 1); + } + } + i += 1; + } + let index = rng.random_range(0..t.len()); + t.insert(index, '%'); + Some(Self(SimValue(Value::build_text( + t.into_iter().collect::(), + )))) + } + _ => None, + } + } +} diff --git a/sql_generation/lib.rs b/sql_generation/lib.rs index 8b1378917..f52cdebdf 100644 --- a/sql_generation/lib.rs +++ b/sql_generation/lib.rs @@ -1 +1,2 @@ - +pub mod generation; +pub mod model; diff --git a/sql_generation/model/mod.rs b/sql_generation/model/mod.rs new file mode 100644 index 000000000..e68355ee4 --- /dev/null +++ b/sql_generation/model/mod.rs @@ -0,0 +1,4 @@ +pub mod query; +pub mod table; + +pub(crate) const FAULT_ERROR_MSG: &str = "Injected fault"; diff --git a/sql_generation/model/query/create.rs b/sql_generation/model/query/create.rs new file mode 100644 index 000000000..ab0cd9789 --- /dev/null +++ b/sql_generation/model/query/create.rs @@ -0,0 +1,45 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use crate::{ + generation::Shadow, + model::table::{SimValue, Table}, + runner::env::SimulatorTables, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Create { + pub(crate) table: Table, +} + +impl Shadow for Create { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + if !tables.iter().any(|t| t.name == self.table.name) { + tables.push(self.table.clone()); + Ok(vec![]) + } else { + Err(anyhow::anyhow!( + "Table {} already exists. 
CREATE TABLE statement ignored.", + self.table.name + )) + } + } +} + +impl Display for Create { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CREATE TABLE {} (", self.table.name)?; + + for (i, column) in self.table.columns.iter().enumerate() { + if i != 0 { + write!(f, ",")?; + } + write!(f, "{} {}", column.name, column.column_type)?; + } + + write!(f, ")") + } +} diff --git a/sql_generation/model/query/create_index.rs b/sql_generation/model/query/create_index.rs new file mode 100644 index 000000000..cc7f7566a --- /dev/null +++ b/sql_generation/model/query/create_index.rs @@ -0,0 +1,106 @@ +use crate::{ + generation::{gen_random_text, pick, pick_n_unique, ArbitraryFrom, Shadow}, + model::table::SimValue, + runner::env::{SimulatorEnv, SimulatorTables}, +}; +use rand::Rng; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub enum SortOrder { + Asc, + Desc, +} + +impl std::fmt::Display for SortOrder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SortOrder::Asc => write!(f, "ASC"), + SortOrder::Desc => write!(f, "DESC"), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub(crate) struct CreateIndex { + pub(crate) index_name: String, + pub(crate) table_name: String, + pub(crate) columns: Vec<(String, SortOrder)>, +} + +impl Shadow for CreateIndex { + type Result = Vec>; + fn shadow(&self, env: &mut SimulatorTables) -> Vec> { + env.tables + .iter_mut() + .find(|t| t.name == self.table_name) + .unwrap() + .indexes + .push(self.index_name.clone()); + vec![] + } +} + +impl std::fmt::Display for CreateIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "CREATE INDEX {} ON {} ({})", + self.index_name, + self.table_name, + self.columns + .iter() + .map(|(name, order)| format!("{name} {order}")) + .collect::>() + .join(", ") + ) + } +} + +impl 
ArbitraryFrom<&SimulatorEnv> for CreateIndex { + fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { + assert!( + !env.tables.is_empty(), + "Cannot create an index when no tables exist in the environment." + ); + + let table = pick(&env.tables, rng); + + if table.columns.is_empty() { + panic!( + "Cannot create an index on table '{}' as it has no columns.", + table.name + ); + } + + let num_columns_to_pick = rng.random_range(1..=table.columns.len()); + let picked_column_indices = pick_n_unique(0..table.columns.len(), num_columns_to_pick, rng); + + let columns = picked_column_indices + .into_iter() + .map(|i| { + let column = &table.columns[i]; + ( + column.name.clone(), + if rng.random_bool(0.5) { + SortOrder::Asc + } else { + SortOrder::Desc + }, + ) + }) + .collect::>(); + + let index_name = format!( + "idx_{}_{}", + table.name, + gen_random_text(rng).chars().take(8).collect::() + ); + + CreateIndex { + index_name, + table_name: table.name.clone(), + columns, + } + } +} diff --git a/sql_generation/model/query/delete.rs b/sql_generation/model/query/delete.rs new file mode 100644 index 000000000..265cdfe96 --- /dev/null +++ b/sql_generation/model/query/delete.rs @@ -0,0 +1,41 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; + +use super::predicate::Predicate; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub(crate) struct Delete { + pub(crate) table: String, + pub(crate) predicate: Predicate, +} + +impl Shadow for Delete { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + let table = tables.tables.iter_mut().find(|t| t.name == self.table); + + if let Some(table) = table { + // If the table exists, we can delete from it + let t2 = table.clone(); + table.rows.retain_mut(|r| !self.predicate.test(r, &t2)); + } else { + // If the table does not exist, we return an error + return 
Err(anyhow::anyhow!( + "Table {} does not exist. DELETE statement ignored.", + self.table + )); + } + + Ok(vec![]) + } +} + +impl Display for Delete { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DELETE FROM {} WHERE {}", self.table, self.predicate) + } +} diff --git a/sql_generation/model/query/drop.rs b/sql_generation/model/query/drop.rs new file mode 100644 index 000000000..2b4379ff9 --- /dev/null +++ b/sql_generation/model/query/drop.rs @@ -0,0 +1,34 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub(crate) struct Drop { + pub(crate) table: String, +} + +impl Shadow for Drop { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + if !tables.iter().any(|t| t.name == self.table) { + // If the table does not exist, we return an error + return Err(anyhow::anyhow!( + "Table {} does not exist. DROP statement ignored.", + self.table + )); + } + + tables.tables.retain(|t| t.name != self.table); + + Ok(vec![]) + } +} + +impl Display for Drop { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DROP TABLE {}", self.table) + } +} diff --git a/sql_generation/model/query/insert.rs b/sql_generation/model/query/insert.rs new file mode 100644 index 000000000..3dc8659df --- /dev/null +++ b/sql_generation/model/query/insert.rs @@ -0,0 +1,87 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; + +use super::select::Select; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub(crate) enum Insert { + Values { + table: String, + values: Vec>, + }, + Select { + table: String, + select: Box
to support resolving a value from another table +// This function attempts to convert an simpler easily computable expression into values +// TODO: In the future, we can try to expand this computation if we want to support harder properties that require us +// to already know more values before hand +pub fn expr_to_value( + expr: &ast::Expr, + row: &[SimValue], + table: &T, +) -> Option { + match expr { + ast::Expr::DoublyQualified(_, _, ast::Name::Ident(col_name)) + | ast::Expr::DoublyQualified(_, _, ast::Name::Quoted(col_name)) + | ast::Expr::Qualified(_, ast::Name::Ident(col_name)) + | ast::Expr::Qualified(_, ast::Name::Quoted(col_name)) + | ast::Expr::Id(ast::Name::Ident(col_name)) => { + let columns = table.columns().collect::>(); + assert_eq!(row.len(), columns.len()); + columns + .iter() + .zip(row.iter()) + .find(|(column, _)| column.column.name == *col_name) + .map(|(_, value)| value) + .cloned() + } + ast::Expr::Literal(literal) => Some(literal.into()), + ast::Expr::Binary(lhs, op, rhs) => { + let lhs = expr_to_value(lhs, row, table)?; + let rhs = expr_to_value(rhs, row, table)?; + Some(lhs.binary_compare(&rhs, *op)) + } + ast::Expr::Like { + lhs, + not, + op, + rhs, + escape: _, // TODO: support escape + } => { + let lhs = expr_to_value(lhs, row, table)?; + let rhs = expr_to_value(rhs, row, table)?; + let res = lhs.like_compare(&rhs, *op); + let value: SimValue = if *not { !res } else { res }.into(); + Some(value) + } + ast::Expr::Unary(op, expr) => { + let value = expr_to_value(expr, row, table)?; + Some(value.unary_exec(*op)) + } + ast::Expr::Parenthesized(exprs) => { + assert_eq!(exprs.len(), 1); + expr_to_value(&exprs[0], row, table) + } + _ => unreachable!("{:?}", expr), + } +} + +impl Display for Predicate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.to_fmt(f) + } +} diff --git a/sql_generation/model/query/select.rs b/sql_generation/model/query/select.rs new file mode 100644 index 000000000..b5e516a0e --- 
/dev/null +++ b/sql_generation/model/query/select.rs @@ -0,0 +1,496 @@ +use std::{collections::HashSet, fmt::Display}; + +use anyhow::Context; +pub use ast::Distinctness; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use turso_parser::ast::{self, fmt::ToTokens, SortOrder}; + +use crate::{ + generation::Shadow, + model::{ + query::EmptyContext, + table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext}, + }, + runner::env::SimulatorTables, +}; + +use super::predicate::Predicate; + +/// `SELECT` or `RETURNING` result column +// https://sqlite.org/syntax/result-column.html +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum ResultColumn { + /// expression + Expr(Predicate), + /// `*` + Star, + /// column name + Column(String), +} + +impl Display for ResultColumn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ResultColumn::Expr(expr) => write!(f, "({expr})"), + ResultColumn::Star => write!(f, "*"), + ResultColumn::Column(name) => write!(f, "{name}"), + } + } +} +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub(crate) struct Select { + pub(crate) body: SelectBody, + pub(crate) limit: Option, +} + +impl Select { + pub fn simple(table: String, where_clause: Predicate) -> Self { + Self::single( + table, + vec![ResultColumn::Star], + where_clause, + None, + Distinctness::All, + ) + } + + pub fn expr(expr: Predicate) -> Self { + Select { + body: SelectBody { + select: Box::new(SelectInner { + distinctness: Distinctness::All, + columns: vec![ResultColumn::Expr(expr)], + from: None, + where_clause: Predicate::true_(), + order_by: None, + }), + compounds: Vec::new(), + }, + limit: None, + } + } + + pub fn single( + table: String, + result_columns: Vec, + where_clause: Predicate, + limit: Option, + distinct: Distinctness, + ) -> Self { + Select { + body: SelectBody { + select: Box::new(SelectInner { + distinctness: distinct, + columns: result_columns, + from: 
Some(FromClause { + table, + joins: Vec::new(), + }), + where_clause, + order_by: None, + }), + compounds: Vec::new(), + }, + limit, + } + } + + pub fn compound(left: Select, right: Select, operator: CompoundOperator) -> Self { + let mut body = left.body; + body.compounds.push(CompoundSelect { + operator, + select: Box::new(right.body.select.as_ref().clone()), + }); + Select { + body, + limit: left.limit.or(right.limit), + } + } + + pub(crate) fn dependencies(&self) -> HashSet { + if self.body.select.from.is_none() { + return HashSet::new(); + } + let from = self.body.select.from.as_ref().unwrap(); + let mut tables = HashSet::new(); + tables.insert(from.table.clone()); + + tables.extend(from.dependencies()); + + for compound in &self.body.compounds { + tables.extend( + compound + .select + .from + .as_ref() + .map(|f| f.dependencies()) + .unwrap_or(vec![]) + .into_iter(), + ); + } + + tables + } +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct SelectBody { + /// first select + pub select: Box, + /// compounds + pub compounds: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct OrderBy { + pub columns: Vec<(String, SortOrder)>, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct SelectInner { + /// `DISTINCT` + pub distinctness: Distinctness, + /// columns + pub columns: Vec, + /// `FROM` clause + pub from: Option, + /// `WHERE` clause + pub where_clause: Predicate, + /// `ORDER BY` clause + pub order_by: Option, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum CompoundOperator { + /// `UNION` + Union, + /// `UNION ALL` + UnionAll, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct CompoundSelect { + /// operator + pub operator: CompoundOperator, + /// select + pub select: Box, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct FromClause { + /// table + pub table: String, + /// `JOIN`ed tables + pub joins: 
Vec, +} + +impl FromClause { + fn to_sql_ast(&self) -> ast::FromClause { + ast::FromClause { + select: Some(Box::new(ast::SelectTable::Table( + ast::QualifiedName::single(ast::Name::from_str(&self.table)), + None, + None, + ))), + joins: if self.joins.is_empty() { + None + } else { + Some( + self.joins + .iter() + .map(|join| ast::JoinedSelectTable { + operator: match join.join_type { + JoinType::Inner => { + ast::JoinOperator::TypedJoin(Some(ast::JoinType::INNER)) + } + JoinType::Left => { + ast::JoinOperator::TypedJoin(Some(ast::JoinType::LEFT)) + } + JoinType::Right => { + ast::JoinOperator::TypedJoin(Some(ast::JoinType::RIGHT)) + } + JoinType::Full => { + ast::JoinOperator::TypedJoin(Some(ast::JoinType::OUTER)) + } + JoinType::Cross => { + ast::JoinOperator::TypedJoin(Some(ast::JoinType::CROSS)) + } + }, + table: ast::SelectTable::Table( + ast::QualifiedName::single(ast::Name::from_str(&join.table)), + None, + None, + ), + constraint: Some(ast::JoinConstraint::On(join.on.0.clone())), + }) + .collect(), + ) + }, + } + } + + pub(crate) fn dependencies(&self) -> Vec { + let mut deps = vec![self.table.clone()]; + for join in &self.joins { + deps.push(join.table.clone()); + } + deps + } +} + +impl Shadow for FromClause { + type Result = anyhow::Result; + fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { + let tables = &mut env.tables; + + let first_table = tables + .iter() + .find(|t| t.name == self.table) + .context("Table not found")?; + + let mut join_table = JoinTable { + tables: vec![first_table.clone()], + rows: Vec::new(), + }; + + for join in &self.joins { + let joined_table = tables + .iter() + .find(|t| t.name == join.table) + .context("Joined table not found")?; + + join_table.tables.push(joined_table.clone()); + + match join.join_type { + JoinType::Inner => { + // Implement inner join logic + let join_rows = joined_table + .rows + .iter() + .filter(|row| join.on.test(row, joined_table)) + .cloned() + .collect::>(); + // take a cartesian 
product of the rows + let all_row_pairs = join_table + .rows + .clone() + .into_iter() + .cartesian_product(join_rows.iter()); + + for (row1, row2) in all_row_pairs { + let row = row1.iter().chain(row2.iter()).cloned().collect::>(); + + let is_in = join.on.test(&row, &join_table); + + if is_in { + join_table.rows.push(row); + } + } + } + _ => todo!(), + } + } + Ok(join_table) + } +} + +impl Shadow for SelectInner { + type Result = anyhow::Result; + + fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { + if let Some(from) = &self.from { + let mut join_table = from.shadow(env)?; + let col_count = join_table.columns().count(); + for row in &mut join_table.rows { + assert_eq!( + row.len(), + col_count, + "Row length does not match column length after join" + ); + } + let join_clone = join_table.clone(); + + join_table + .rows + .retain(|row| self.where_clause.test(row, &join_clone)); + + if self.distinctness == Distinctness::Distinct { + join_table.rows.sort_unstable(); + join_table.rows.dedup(); + } + + Ok(join_table) + } else { + assert!(self + .columns + .iter() + .all(|col| matches!(col, ResultColumn::Expr(_)))); + + // If `WHERE` is false, just return an empty table + if !self.where_clause.test(&[], &Table::anonymous(vec![])) { + return Ok(JoinTable { + tables: Vec::new(), + rows: Vec::new(), + }); + } + + // Compute the results of the column expressions and make a row + let mut row = Vec::new(); + for col in &self.columns { + match col { + ResultColumn::Expr(expr) => { + let value = expr.eval(&[], &Table::anonymous(vec![])); + if let Some(value) = value { + row.push(value); + } else { + return Err(anyhow::anyhow!( + "Failed to evaluate expression in free select ({})", + expr.0.format_with_context(&EmptyContext {}).unwrap() + )); + } + } + _ => unreachable!("Only expressions are allowed in free selects"), + } + } + + Ok(JoinTable { + tables: Vec::new(), + rows: vec![row], + }) + } + } +} + +impl Shadow for Select { + type Result = anyhow::Result>>; + + 
fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { + let first_result = self.body.select.shadow(env)?; + + let mut rows = first_result.rows; + + for compound in self.body.compounds.iter() { + let compound_results = compound.select.shadow(env)?; + + match compound.operator { + CompoundOperator::Union => { + // Union means we need to combine the results, removing duplicates + let mut new_rows = compound_results.rows; + new_rows.extend(rows.clone()); + new_rows.sort_unstable(); + new_rows.dedup(); + rows = new_rows; + } + CompoundOperator::UnionAll => { + // Union all means we just concatenate the results + rows.extend(compound_results.rows.into_iter()); + } + } + } + + Ok(rows) + } +} + +impl Select { + pub fn to_sql_ast(&self) -> ast::Select { + ast::Select { + with: None, + body: ast::SelectBody { + select: Box::new(ast::OneSelect::Select(Box::new(ast::SelectInner { + distinctness: if self.body.select.distinctness == Distinctness::Distinct { + Some(ast::Distinctness::Distinct) + } else { + None + }, + columns: self + .body + .select + .columns + .iter() + .map(|col| match col { + ResultColumn::Expr(expr) => { + ast::ResultColumn::Expr(expr.0.clone(), None) + } + ResultColumn::Star => ast::ResultColumn::Star, + ResultColumn::Column(name) => ast::ResultColumn::Expr( + ast::Expr::Id(ast::Name::Ident(name.clone())), + None, + ), + }) + .collect(), + from: self.body.select.from.as_ref().map(|f| f.to_sql_ast()), + where_clause: Some(self.body.select.where_clause.0.clone()), + group_by: None, + window_clause: None, + }))), + compounds: Some( + self.body + .compounds + .iter() + .map(|compound| ast::CompoundSelect { + operator: match compound.operator { + CompoundOperator::Union => ast::CompoundOperator::Union, + CompoundOperator::UnionAll => ast::CompoundOperator::UnionAll, + }, + select: Box::new(ast::OneSelect::Select(Box::new(ast::SelectInner { + distinctness: Some(compound.select.distinctness), + columns: compound + .select + .columns + .iter() + .map(|col| 
match col { + ResultColumn::Expr(expr) => { + ast::ResultColumn::Expr(expr.0.clone(), None) + } + ResultColumn::Star => ast::ResultColumn::Star, + ResultColumn::Column(name) => ast::ResultColumn::Expr( + ast::Expr::Id(ast::Name::Ident(name.clone())), + None, + ), + }) + .collect(), + from: compound.select.from.as_ref().map(|f| f.to_sql_ast()), + where_clause: Some(compound.select.where_clause.0.clone()), + group_by: None, + window_clause: None, + }))), + }) + .collect(), + ), + }, + order_by: self.body.select.order_by.as_ref().map(|o| { + o.columns + .iter() + .map(|(name, order)| ast::SortedColumn { + expr: ast::Expr::Id(ast::Name::Ident(name.clone())), + order: match order { + SortOrder::Asc => Some(ast::SortOrder::Asc), + SortOrder::Desc => Some(ast::SortOrder::Desc), + }, + nulls: None, + }) + .collect() + }), + limit: self.limit.map(|l| { + Box::new(ast::Limit { + expr: ast::Expr::Literal(ast::Literal::Numeric(l.to_string())), + offset: None, + }) + }), + } + } +} +impl Display for Select { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_sql_ast().to_fmt_with_context(f, &EmptyContext {}) + } +} + +#[cfg(test)] +mod select_tests { + + #[test] + fn test_select_display() {} +} diff --git a/sql_generation/model/query/transaction.rs b/sql_generation/model/query/transaction.rs new file mode 100644 index 000000000..a73fb076e --- /dev/null +++ b/sql_generation/model/query/transaction.rs @@ -0,0 +1,60 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Begin { + pub(crate) immediate: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Commit; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Rollback; + +impl Shadow for Begin { + type Result = Vec>; + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + 
tables.snapshot = Some(tables.tables.clone()); + vec![] + } +} + +impl Shadow for Commit { + type Result = Vec>; + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + tables.snapshot = None; + vec![] + } +} + +impl Shadow for Rollback { + type Result = Vec>; + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + if let Some(tables_) = tables.snapshot.take() { + tables.tables = tables_; + } + vec![] + } +} + +impl Display for Begin { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BEGIN {}", if self.immediate { "IMMEDIATE" } else { "" }) + } +} + +impl Display for Commit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "COMMIT") + } +} + +impl Display for Rollback { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ROLLBACK") + } +} diff --git a/sql_generation/model/query/update.rs b/sql_generation/model/query/update.rs new file mode 100644 index 000000000..a4cc13fa8 --- /dev/null +++ b/sql_generation/model/query/update.rs @@ -0,0 +1,71 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; + +use super::predicate::Predicate; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub(crate) struct Update { + pub(crate) table: String, + pub(crate) set_values: Vec<(String, SimValue)>, // Pair of value for set expressions => SET name=value + pub(crate) predicate: Predicate, +} + +impl Update { + pub fn table(&self) -> &str { + &self.table + } +} + +impl Shadow for Update { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + let table = tables.tables.iter_mut().find(|t| t.name == self.table); + + let table = if let Some(table) = table { + table + } else { + return Err(anyhow::anyhow!( + "Table {} does not exist. 
UPDATE statement ignored.", + self.table + )); + }; + + let t2 = table.clone(); + for row in table + .rows + .iter_mut() + .filter(|r| self.predicate.test(r, &t2)) + { + for (column, set_value) in &self.set_values { + if let Some((idx, _)) = table + .columns + .iter() + .enumerate() + .find(|(_, c)| &c.name == column) + { + row[idx] = set_value.clone(); + } + } + } + + Ok(vec![]) + } +} + +impl Display for Update { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "UPDATE {} SET ", self.table)?; + for (i, (name, value)) in self.set_values.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{name} = {value}")?; + } + write!(f, " WHERE {}", self.predicate)?; + Ok(()) + } +} diff --git a/sql_generation/model/table.rs b/sql_generation/model/table.rs new file mode 100644 index 000000000..210039e17 --- /dev/null +++ b/sql_generation/model/table.rs @@ -0,0 +1,428 @@ +use std::{fmt::Display, hash::Hash, ops::Deref}; + +use serde::{Deserialize, Serialize}; +use turso_core::{numeric::Numeric, types}; +use turso_parser::ast; + +use crate::model::query::predicate::Predicate; + +pub(crate) struct Name(pub(crate) String); + +impl Deref for Name { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ContextColumn<'a> { + pub table_name: &'a str, + pub column: &'a Column, +} + +pub trait TableContext { + fn columns<'a>(&'a self) -> impl Iterator>; + fn rows(&self) -> &Vec>; +} + +impl TableContext for Table { + fn columns<'a>(&'a self) -> impl Iterator> { + self.columns.iter().map(|col| ContextColumn { + column: col, + table_name: &self.name, + }) + } + + fn rows(&self) -> &Vec> { + &self.rows + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Table { + pub(crate) name: String, + pub(crate) columns: Vec, + pub(crate) rows: Vec>, + pub(crate) indexes: Vec, +} + +impl Table { + pub fn anonymous(rows: Vec>) -> Self { + Self { + rows, + 
name: "".to_string(), + columns: vec![], + indexes: vec![], + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Column { + pub(crate) name: String, + pub(crate) column_type: ColumnType, + pub(crate) primary: bool, + pub(crate) unique: bool, +} + +// Uniquely defined by name in this case +impl Hash for Column { + fn hash(&self, state: &mut H) { + self.name.hash(state); + } +} + +impl PartialEq for Column { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + } +} + +impl Eq for Column {} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) enum ColumnType { + Integer, + Float, + Text, + Blob, +} + +impl Display for ColumnType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Integer => write!(f, "INTEGER"), + Self::Float => write!(f, "REAL"), + Self::Text => write!(f, "TEXT"), + Self::Blob => write!(f, "BLOB"), + } + } +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct JoinedTable { + /// table name + pub table: String, + /// `JOIN` type + pub join_type: JoinType, + /// `ON` clause + pub on: Predicate, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum JoinType { + Inner, + Left, + Right, + Full, + Cross, +} + +impl TableContext for JoinTable { + fn columns<'a>(&'a self) -> impl Iterator> { + self.tables.iter().flat_map(|table| table.columns()) + } + + fn rows(&self) -> &Vec> { + &self.rows + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct JoinTable { + pub tables: Vec
, + pub rows: Vec>, +} + +fn float_to_string(float: &f64, serializer: S) -> Result +where + S: serde::Serializer, +{ + serializer.serialize_str(&format!("{float}")) +} + +fn string_to_float<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + s.parse().map_err(serde::de::Error::custom) +} + +#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Serialize, Deserialize)] +pub(crate) struct SimValue(pub turso_core::Value); + +fn to_sqlite_blob(bytes: &[u8]) -> String { + format!( + "X'{}'", + bytes + .iter() + .fold(String::new(), |acc, b| acc + &format!("{b:02X}")) + ) +} + +impl Display for SimValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.0 { + types::Value::Null => write!(f, "NULL"), + types::Value::Integer(i) => write!(f, "{i}"), + types::Value::Float(fl) => write!(f, "{fl}"), + value @ types::Value::Text(..) => write!(f, "'{value}'"), + types::Value::Blob(b) => write!(f, "{}", to_sqlite_blob(b)), + } + } +} + +impl SimValue { + pub const FALSE: Self = SimValue(types::Value::Integer(0)); + pub const TRUE: Self = SimValue(types::Value::Integer(1)); + + pub fn as_bool(&self) -> bool { + Numeric::from(&self.0).try_into_bool().unwrap_or_default() + } + + // TODO: support more predicates + /// Returns a Result of a Binary Operation + /// + /// TODO: forget collations for now + /// TODO: have the [ast::Operator::Equals], [ast::Operator::NotEquals], [ast::Operator::Greater], + /// [ast::Operator::GreaterEquals], [ast::Operator::Less], [ast::Operator::LessEquals] function to be extracted + /// into its functions in turso_core so that it can be used here + pub fn binary_compare(&self, other: &Self, operator: ast::Operator) -> SimValue { + match operator { + ast::Operator::Add => self.0.exec_add(&other.0).into(), + ast::Operator::And => self.0.exec_and(&other.0).into(), + ast::Operator::ArrowRight => todo!(), + ast::Operator::ArrowRightShift => todo!(), + 
ast::Operator::BitwiseAnd => self.0.exec_bit_and(&other.0).into(), + ast::Operator::BitwiseOr => self.0.exec_bit_or(&other.0).into(), + ast::Operator::BitwiseNot => todo!(), // TODO: Do not see any function usage of this operator in Core + ast::Operator::Concat => self.0.exec_concat(&other.0).into(), + ast::Operator::Equals => (self == other).into(), + ast::Operator::Divide => self.0.exec_divide(&other.0).into(), + ast::Operator::Greater => (self > other).into(), + ast::Operator::GreaterEquals => (self >= other).into(), + // TODO: Test these implementations + ast::Operator::Is => match (&self.0, &other.0) { + (types::Value::Null, types::Value::Null) => true.into(), + (types::Value::Null, _) => false.into(), + (_, types::Value::Null) => false.into(), + _ => self.binary_compare(other, ast::Operator::Equals), + }, + ast::Operator::IsNot => self + .binary_compare(other, ast::Operator::Is) + .unary_exec(ast::UnaryOperator::Not), + ast::Operator::LeftShift => self.0.exec_shift_left(&other.0).into(), + ast::Operator::Less => (self < other).into(), + ast::Operator::LessEquals => (self <= other).into(), + ast::Operator::Modulus => self.0.exec_remainder(&other.0).into(), + ast::Operator::Multiply => self.0.exec_multiply(&other.0).into(), + ast::Operator::NotEquals => (self != other).into(), + ast::Operator::Or => self.0.exec_or(&other.0).into(), + ast::Operator::RightShift => self.0.exec_shift_right(&other.0).into(), + ast::Operator::Subtract => self.0.exec_subtract(&other.0).into(), + } + } + + // TODO: support more operators. 
Copy the implementation for exec_glob + pub fn like_compare(&self, other: &Self, operator: ast::LikeOperator) -> bool { + match operator { + ast::LikeOperator::Glob => todo!(), + ast::LikeOperator::Like => { + // TODO: support ESCAPE `expr` option in AST + // TODO: regex cache + types::Value::exec_like( + None, + other.0.to_string().as_str(), + self.0.to_string().as_str(), + ) + } + ast::LikeOperator::Match => todo!(), + ast::LikeOperator::Regexp => todo!(), + } + } + + pub fn unary_exec(&self, operator: ast::UnaryOperator) -> SimValue { + let new_value = match operator { + ast::UnaryOperator::BitwiseNot => self.0.exec_bit_not(), + ast::UnaryOperator::Negative => { + SimValue(types::Value::Integer(0)) + .binary_compare(self, ast::Operator::Subtract) + .0 + } + ast::UnaryOperator::Not => self.0.exec_boolean_not(), + ast::UnaryOperator::Positive => self.0.clone(), + }; + Self(new_value) + } +} + +impl From for SimValue { + fn from(value: ast::Literal) -> Self { + Self::from(&value) + } +} + +/// Converts a SQL string literal with already-escaped single quotes to a regular string by: +/// - Removing the enclosing single quotes +/// - Converting sequences of 2N single quotes ('''''') to N single quotes (''') +/// +/// Assumes: +/// - The input starts and ends with a single quote +/// - The input contains a valid amount of single quotes inside the enclosing quotes; +/// i.e. 
any ' is escaped as a double '' +fn unescape_singlequotes(input: &str) -> String { + assert!( + input.starts_with('\'') && input.ends_with('\''), + "Input string must be wrapped in single quotes" + ); + // Skip first and last characters (the enclosing quotes) + let inner = &input[1..input.len() - 1]; + + let mut result = String::with_capacity(inner.len()); + let mut chars = inner.chars().peekable(); + + while let Some(c) = chars.next() { + if c == '\'' { + // Count consecutive single quotes + let mut quote_count = 1; + while chars.peek() == Some(&'\'') { + quote_count += 1; + chars.next(); + } + assert!( + quote_count % 2 == 0, + "Expected even number of quotes, got {quote_count} in string {input}" + ); + // For every pair of quotes, output one quote + for _ in 0..(quote_count / 2) { + result.push('\''); + } + } else { + result.push(c); + } + } + + result +} + +/// Escapes a string by doubling contained single quotes and then wrapping it in single quotes. +fn escape_singlequotes(input: &str) -> String { + let mut result = String::with_capacity(input.len() + 2); + result.push('\''); + result.push_str(&input.replace("'", "''")); + result.push('\''); + result +} + +impl From<&ast::Literal> for SimValue { + fn from(value: &ast::Literal) -> Self { + let new_value = match value { + ast::Literal::Null => types::Value::Null, + ast::Literal::Numeric(number) => Numeric::from(number).into(), + ast::Literal::String(string) => types::Value::build_text(unescape_singlequotes(string)), + ast::Literal::Blob(blob) => types::Value::Blob( + blob.as_bytes() + .chunks_exact(2) + .map(|pair| { + // We assume that sqlite3-parser has already validated that + // the input is valid hex string, thus unwrap is safe. 
+ let hex_byte = std::str::from_utf8(pair).unwrap(); + u8::from_str_radix(hex_byte, 16).unwrap() + }) + .collect(), + ), + ast::Literal::Keyword(keyword) => match keyword.to_uppercase().as_str() { + "TRUE" => types::Value::Integer(1), + "FALSE" => types::Value::Integer(0), + "NULL" => types::Value::Null, + _ => unimplemented!("Unsupported keyword literal: {}", keyword), + }, + lit => unimplemented!("{:?}", lit), + }; + Self(new_value) + } +} + +impl From for ast::Literal { + fn from(value: SimValue) -> Self { + Self::from(&value) + } +} + +impl From<&SimValue> for ast::Literal { + fn from(value: &SimValue) -> Self { + match &value.0 { + types::Value::Null => Self::Null, + types::Value::Integer(i) => Self::Numeric(i.to_string()), + types::Value::Float(f) => Self::Numeric(f.to_string()), + text @ types::Value::Text(..) => Self::String(escape_singlequotes(&text.to_string())), + types::Value::Blob(blob) => Self::Blob(hex::encode(blob)), + } + } +} + +impl From for SimValue { + fn from(value: bool) -> Self { + if value { + SimValue::TRUE + } else { + SimValue::FALSE + } + } +} + +impl From for turso_core::types::Value { + fn from(value: SimValue) -> Self { + value.0 + } +} + +impl From<&SimValue> for turso_core::types::Value { + fn from(value: &SimValue) -> Self { + value.0.clone() + } +} + +impl From for SimValue { + fn from(value: turso_core::types::Value) -> Self { + Self(value) + } +} + +impl From<&turso_core::types::Value> for SimValue { + fn from(value: &turso_core::types::Value) -> Self { + Self(value.clone()) + } +} + +#[cfg(test)] +mod tests { + use crate::model::table::{escape_singlequotes, unescape_singlequotes}; + + #[test] + fn test_unescape_singlequotes() { + assert_eq!(unescape_singlequotes("'hello'"), "hello"); + assert_eq!(unescape_singlequotes("'O''Reilly'"), "O'Reilly"); + assert_eq!( + unescape_singlequotes("'multiple''single''quotes'"), + "multiple'single'quotes" + ); + assert_eq!(unescape_singlequotes("'test''''test'"), "test''test"); + 
assert_eq!(unescape_singlequotes("'many''''''quotes'"), "many'''quotes"); + } + + #[test] + fn test_escape_singlequotes() { + assert_eq!(escape_singlequotes("hello"), "'hello'"); + assert_eq!(escape_singlequotes("O'Reilly"), "'O''Reilly'"); + assert_eq!( + escape_singlequotes("multiple'single'quotes"), + "'multiple''single''quotes'" + ); + assert_eq!(escape_singlequotes("test''test"), "'test''''test'"); + assert_eq!(escape_singlequotes("many'''quotes"), "'many''''''quotes'"); + } +} From 642060f2837cca01d4773fe02e375391cfbe40f6 Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Mon, 25 Aug 2025 15:53:45 -0300 Subject: [PATCH 60/73] refactor sql_generation/model/query --- parser/src/ast.rs | 35 +++ sql_generation/model/query/create.rs | 22 +- sql_generation/model/query/create_index.rs | 67 ---- sql_generation/model/query/delete.rs | 24 -- sql_generation/model/query/drop.rs | 20 -- sql_generation/model/query/insert.rs | 34 +-- sql_generation/model/query/mod.rs | 28 +- sql_generation/model/query/select.rs | 339 +++++---------------- sql_generation/model/query/transaction.rs | 28 -- sql_generation/model/query/update.rs | 39 +-- 10 files changed, 123 insertions(+), 513 deletions(-) diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 5626ffbaa..44a427dcb 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -982,6 +982,41 @@ impl std::fmt::Display for QualifiedName { } } +impl QualifiedName { + /// Constructor + pub fn single(name: Name) -> Self { + Self { + db_name: None, + name, + alias: None, + } + } + /// Constructor + pub fn fullname(db_name: Name, name: Name) -> Self { + Self { + db_name: Some(db_name), + name, + alias: None, + } + } + /// Constructor + pub fn xfullname(db_name: Name, name: Name, alias: Name) -> Self { + Self { + db_name: Some(db_name), + name, + alias: Some(alias), + } + } + /// Constructor + pub fn alias(name: Name, alias: Name) -> Self { + Self { + db_name: None, + name, + alias: Some(alias), + } + } +} + /// `ALTER TABLE` body // 
https://sqlite.org/lang_altertable.html #[derive(Clone, Debug, PartialEq, Eq)] diff --git a/sql_generation/model/query/create.rs b/sql_generation/model/query/create.rs index ab0cd9789..e628b5cc5 100644 --- a/sql_generation/model/query/create.rs +++ b/sql_generation/model/query/create.rs @@ -2,33 +2,13 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; -use crate::{ - generation::Shadow, - model::table::{SimValue, Table}, - runner::env::SimulatorTables, -}; +use crate::model::table::Table; #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct Create { pub(crate) table: Table, } -impl Shadow for Create { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - if !tables.iter().any(|t| t.name == self.table.name) { - tables.push(self.table.clone()); - Ok(vec![]) - } else { - Err(anyhow::anyhow!( - "Table {} already exists. CREATE TABLE statement ignored.", - self.table.name - )) - } - } -} - impl Display for Create { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "CREATE TABLE {} (", self.table.name)?; diff --git a/sql_generation/model/query/create_index.rs b/sql_generation/model/query/create_index.rs index cc7f7566a..aba0f98bf 100644 --- a/sql_generation/model/query/create_index.rs +++ b/sql_generation/model/query/create_index.rs @@ -1,9 +1,3 @@ -use crate::{ - generation::{gen_random_text, pick, pick_n_unique, ArbitraryFrom, Shadow}, - model::table::SimValue, - runner::env::{SimulatorEnv, SimulatorTables}, -}; -use rand::Rng; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -28,19 +22,6 @@ pub(crate) struct CreateIndex { pub(crate) columns: Vec<(String, SortOrder)>, } -impl Shadow for CreateIndex { - type Result = Vec>; - fn shadow(&self, env: &mut SimulatorTables) -> Vec> { - env.tables - .iter_mut() - .find(|t| t.name == self.table_name) - .unwrap() - .indexes - .push(self.index_name.clone()); - vec![] 
- } -} - impl std::fmt::Display for CreateIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -56,51 +37,3 @@ impl std::fmt::Display for CreateIndex { ) } } - -impl ArbitraryFrom<&SimulatorEnv> for CreateIndex { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - assert!( - !env.tables.is_empty(), - "Cannot create an index when no tables exist in the environment." - ); - - let table = pick(&env.tables, rng); - - if table.columns.is_empty() { - panic!( - "Cannot create an index on table '{}' as it has no columns.", - table.name - ); - } - - let num_columns_to_pick = rng.random_range(1..=table.columns.len()); - let picked_column_indices = pick_n_unique(0..table.columns.len(), num_columns_to_pick, rng); - - let columns = picked_column_indices - .into_iter() - .map(|i| { - let column = &table.columns[i]; - ( - column.name.clone(), - if rng.random_bool(0.5) { - SortOrder::Asc - } else { - SortOrder::Desc - }, - ) - }) - .collect::>(); - - let index_name = format!( - "idx_{}_{}", - table.name, - gen_random_text(rng).chars().take(8).collect::() - ); - - CreateIndex { - index_name, - table_name: table.name.clone(), - columns, - } - } -} diff --git a/sql_generation/model/query/delete.rs b/sql_generation/model/query/delete.rs index 265cdfe96..a86479850 100644 --- a/sql_generation/model/query/delete.rs +++ b/sql_generation/model/query/delete.rs @@ -2,8 +2,6 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; - use super::predicate::Predicate; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -12,28 +10,6 @@ pub(crate) struct Delete { pub(crate) predicate: Predicate, } -impl Shadow for Delete { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - let table = tables.tables.iter_mut().find(|t| t.name == self.table); - - if let Some(table) = table { - // If the table 
exists, we can delete from it - let t2 = table.clone(); - table.rows.retain_mut(|r| !self.predicate.test(r, &t2)); - } else { - // If the table does not exist, we return an error - return Err(anyhow::anyhow!( - "Table {} does not exist. DELETE statement ignored.", - self.table - )); - } - - Ok(vec![]) - } -} - impl Display for Delete { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "DELETE FROM {} WHERE {}", self.table, self.predicate) diff --git a/sql_generation/model/query/drop.rs b/sql_generation/model/query/drop.rs index 2b4379ff9..d9a34a9e9 100644 --- a/sql_generation/model/query/drop.rs +++ b/sql_generation/model/query/drop.rs @@ -2,31 +2,11 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub(crate) struct Drop { pub(crate) table: String, } -impl Shadow for Drop { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - if !tables.iter().any(|t| t.name == self.table) { - // If the table does not exist, we return an error - return Err(anyhow::anyhow!( - "Table {} does not exist. 
DROP statement ignored.", - self.table - )); - } - - tables.tables.retain(|t| t.name != self.table); - - Ok(vec![]) - } -} - impl Display for Drop { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "DROP TABLE {}", self.table) diff --git a/sql_generation/model/query/insert.rs b/sql_generation/model/query/insert.rs index 3dc8659df..9fd391612 100644 --- a/sql_generation/model/query/insert.rs +++ b/sql_generation/model/query/insert.rs @@ -2,7 +2,7 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; +use crate::model::table::SimValue; use super::select::Select; @@ -18,38 +18,6 @@ pub(crate) enum Insert { }, } -impl Shadow for Insert { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - match self { - Insert::Values { table, values } => { - if let Some(t) = tables.tables.iter_mut().find(|t| &t.name == table) { - t.rows.extend(values.clone()); - } else { - return Err(anyhow::anyhow!( - "Table {} does not exist. INSERT statement ignored.", - table - )); - } - } - Insert::Select { table, select } => { - let rows = select.shadow(tables)?; - if let Some(t) = tables.tables.iter_mut().find(|t| &t.name == table) { - t.rows.extend(rows); - } else { - return Err(anyhow::anyhow!( - "Table {} does not exist. 
INSERT statement ignored.", - table - )); - } - } - } - - Ok(vec![]) - } -} - impl Insert { pub(crate) fn table(&self) -> &str { match self { diff --git a/sql_generation/model/query/mod.rs b/sql_generation/model/query/mod.rs index 2c8704003..9ae222a9a 100644 --- a/sql_generation/model/query/mod.rs +++ b/sql_generation/model/query/mod.rs @@ -10,14 +10,7 @@ use serde::{Deserialize, Serialize}; use turso_parser::ast::fmt::ToSqlContext; use update::Update; -use crate::{ - generation::Shadow, - model::{ - query::transaction::{Begin, Commit, Rollback}, - table::SimValue, - }, - runner::env::SimulatorTables, -}; +use crate::model::query::transaction::{Begin, Commit, Rollback}; pub mod create; pub mod create_index; @@ -75,25 +68,6 @@ impl Query { } } -impl Shadow for Query { - type Result = anyhow::Result>>; - - fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { - match self { - Query::Create(create) => create.shadow(env), - Query::Insert(insert) => insert.shadow(env), - Query::Delete(delete) => delete.shadow(env), - Query::Select(select) => select.shadow(env), - Query::Update(update) => update.shadow(env), - Query::Drop(drop) => drop.shadow(env), - Query::CreateIndex(create_index) => Ok(create_index.shadow(env)), - Query::Begin(begin) => Ok(begin.shadow(env)), - Query::Commit(commit) => Ok(commit.shadow(env)), - Query::Rollback(rollback) => Ok(rollback.shadow(env)), - } - } -} - impl Display for Query { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/sql_generation/model/query/select.rs b/sql_generation/model/query/select.rs index b5e516a0e..6c7897e1d 100644 --- a/sql_generation/model/query/select.rs +++ b/sql_generation/model/query/select.rs @@ -1,18 +1,12 @@ use std::{collections::HashSet, fmt::Display}; -use anyhow::Context; pub use ast::Distinctness; -use itertools::Itertools; use serde::{Deserialize, Serialize}; use turso_parser::ast::{self, fmt::ToTokens, SortOrder}; -use crate::{ - generation::Shadow, - 
model::{ - query::EmptyContext, - table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext}, - }, - runner::env::SimulatorTables, +use crate::model::{ + query::EmptyContext, + table::{JoinType, JoinedTable}, }; use super::predicate::Predicate; @@ -188,45 +182,30 @@ pub struct FromClause { impl FromClause { fn to_sql_ast(&self) -> ast::FromClause { ast::FromClause { - select: Some(Box::new(ast::SelectTable::Table( - ast::QualifiedName::single(ast::Name::from_str(&self.table)), + select: Box::new(ast::SelectTable::Table( + ast::QualifiedName::single(ast::Name::new(&self.table)), None, None, - ))), - joins: if self.joins.is_empty() { - None - } else { - Some( - self.joins - .iter() - .map(|join| ast::JoinedSelectTable { - operator: match join.join_type { - JoinType::Inner => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::INNER)) - } - JoinType::Left => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::LEFT)) - } - JoinType::Right => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::RIGHT)) - } - JoinType::Full => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::OUTER)) - } - JoinType::Cross => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::CROSS)) - } - }, - table: ast::SelectTable::Table( - ast::QualifiedName::single(ast::Name::from_str(&join.table)), - None, - None, - ), - constraint: Some(ast::JoinConstraint::On(join.on.0.clone())), - }) - .collect(), - ) - }, + )), + joins: self + .joins + .iter() + .map(|join| ast::JoinedSelectTable { + operator: match join.join_type { + JoinType::Inner => ast::JoinOperator::TypedJoin(Some(ast::JoinType::INNER)), + JoinType::Left => ast::JoinOperator::TypedJoin(Some(ast::JoinType::LEFT)), + JoinType::Right => ast::JoinOperator::TypedJoin(Some(ast::JoinType::RIGHT)), + JoinType::Full => ast::JoinOperator::TypedJoin(Some(ast::JoinType::OUTER)), + JoinType::Cross => ast::JoinOperator::TypedJoin(Some(ast::JoinType::CROSS)), + }, + table: Box::new(ast::SelectTable::Table( + 
ast::QualifiedName::single(ast::Name::new(&join.table)), + None, + None, + )), + constraint: Some(ast::JoinConstraint::On(Box::new(join.on.0.clone()))), + }) + .collect(), } } @@ -239,166 +218,12 @@ impl FromClause { } } -impl Shadow for FromClause { - type Result = anyhow::Result; - fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { - let tables = &mut env.tables; - - let first_table = tables - .iter() - .find(|t| t.name == self.table) - .context("Table not found")?; - - let mut join_table = JoinTable { - tables: vec![first_table.clone()], - rows: Vec::new(), - }; - - for join in &self.joins { - let joined_table = tables - .iter() - .find(|t| t.name == join.table) - .context("Joined table not found")?; - - join_table.tables.push(joined_table.clone()); - - match join.join_type { - JoinType::Inner => { - // Implement inner join logic - let join_rows = joined_table - .rows - .iter() - .filter(|row| join.on.test(row, joined_table)) - .cloned() - .collect::>(); - // take a cartesian product of the rows - let all_row_pairs = join_table - .rows - .clone() - .into_iter() - .cartesian_product(join_rows.iter()); - - for (row1, row2) in all_row_pairs { - let row = row1.iter().chain(row2.iter()).cloned().collect::>(); - - let is_in = join.on.test(&row, &join_table); - - if is_in { - join_table.rows.push(row); - } - } - } - _ => todo!(), - } - } - Ok(join_table) - } -} - -impl Shadow for SelectInner { - type Result = anyhow::Result; - - fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { - if let Some(from) = &self.from { - let mut join_table = from.shadow(env)?; - let col_count = join_table.columns().count(); - for row in &mut join_table.rows { - assert_eq!( - row.len(), - col_count, - "Row length does not match column length after join" - ); - } - let join_clone = join_table.clone(); - - join_table - .rows - .retain(|row| self.where_clause.test(row, &join_clone)); - - if self.distinctness == Distinctness::Distinct { - join_table.rows.sort_unstable(); - 
join_table.rows.dedup(); - } - - Ok(join_table) - } else { - assert!(self - .columns - .iter() - .all(|col| matches!(col, ResultColumn::Expr(_)))); - - // If `WHERE` is false, just return an empty table - if !self.where_clause.test(&[], &Table::anonymous(vec![])) { - return Ok(JoinTable { - tables: Vec::new(), - rows: Vec::new(), - }); - } - - // Compute the results of the column expressions and make a row - let mut row = Vec::new(); - for col in &self.columns { - match col { - ResultColumn::Expr(expr) => { - let value = expr.eval(&[], &Table::anonymous(vec![])); - if let Some(value) = value { - row.push(value); - } else { - return Err(anyhow::anyhow!( - "Failed to evaluate expression in free select ({})", - expr.0.format_with_context(&EmptyContext {}).unwrap() - )); - } - } - _ => unreachable!("Only expressions are allowed in free selects"), - } - } - - Ok(JoinTable { - tables: Vec::new(), - rows: vec![row], - }) - } - } -} - -impl Shadow for Select { - type Result = anyhow::Result>>; - - fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { - let first_result = self.body.select.shadow(env)?; - - let mut rows = first_result.rows; - - for compound in self.body.compounds.iter() { - let compound_results = compound.select.shadow(env)?; - - match compound.operator { - CompoundOperator::Union => { - // Union means we need to combine the results, removing duplicates - let mut new_rows = compound_results.rows; - new_rows.extend(rows.clone()); - new_rows.sort_unstable(); - new_rows.dedup(); - rows = new_rows; - } - CompoundOperator::UnionAll => { - // Union all means we just concatenate the results - rows.extend(compound_results.rows.into_iter()); - } - } - } - - Ok(rows) - } -} - impl Select { pub fn to_sql_ast(&self) -> ast::Select { ast::Select { with: None, body: ast::SelectBody { - select: Box::new(ast::OneSelect::Select(Box::new(ast::SelectInner { + select: ast::OneSelect::Select { distinctness: if self.body.select.distinctness == Distinctness::Distinct { 
Some(ast::Distinctness::Distinct) } else { @@ -411,77 +236,81 @@ impl Select { .iter() .map(|col| match col { ResultColumn::Expr(expr) => { - ast::ResultColumn::Expr(expr.0.clone(), None) + ast::ResultColumn::Expr(expr.0.clone().into_boxed(), None) } ResultColumn::Star => ast::ResultColumn::Star, ResultColumn::Column(name) => ast::ResultColumn::Expr( - ast::Expr::Id(ast::Name::Ident(name.clone())), + ast::Expr::Id(ast::Name::Ident(name.clone())).into_boxed(), None, ), }) .collect(), from: self.body.select.from.as_ref().map(|f| f.to_sql_ast()), - where_clause: Some(self.body.select.where_clause.0.clone()), + where_clause: Some(self.body.select.where_clause.0.clone().into_boxed()), group_by: None, - window_clause: None, - }))), - compounds: Some( - self.body - .compounds - .iter() - .map(|compound| ast::CompoundSelect { - operator: match compound.operator { - CompoundOperator::Union => ast::CompoundOperator::Union, - CompoundOperator::UnionAll => ast::CompoundOperator::UnionAll, - }, - select: Box::new(ast::OneSelect::Select(Box::new(ast::SelectInner { - distinctness: Some(compound.select.distinctness), - columns: compound - .select - .columns - .iter() - .map(|col| match col { - ResultColumn::Expr(expr) => { - ast::ResultColumn::Expr(expr.0.clone(), None) - } - ResultColumn::Star => ast::ResultColumn::Star, - ResultColumn::Column(name) => ast::ResultColumn::Expr( - ast::Expr::Id(ast::Name::Ident(name.clone())), - None, - ), - }) - .collect(), - from: compound.select.from.as_ref().map(|f| f.to_sql_ast()), - where_clause: Some(compound.select.where_clause.0.clone()), - group_by: None, - window_clause: None, - }))), - }) - .collect(), - ), - }, - order_by: self.body.select.order_by.as_ref().map(|o| { - o.columns + window_clause: Vec::new(), + }, + compounds: self + .body + .compounds .iter() - .map(|(name, order)| ast::SortedColumn { - expr: ast::Expr::Id(ast::Name::Ident(name.clone())), - order: match order { - SortOrder::Asc => Some(ast::SortOrder::Asc), - 
SortOrder::Desc => Some(ast::SortOrder::Desc), + .map(|compound| ast::CompoundSelect { + operator: match compound.operator { + CompoundOperator::Union => ast::CompoundOperator::Union, + CompoundOperator::UnionAll => ast::CompoundOperator::UnionAll, + }, + select: ast::OneSelect::Select { + distinctness: Some(compound.select.distinctness), + columns: compound + .select + .columns + .iter() + .map(|col| match col { + ResultColumn::Expr(expr) => { + ast::ResultColumn::Expr(expr.0.clone().into_boxed(), None) + } + ResultColumn::Star => ast::ResultColumn::Star, + ResultColumn::Column(name) => ast::ResultColumn::Expr( + ast::Expr::Id(ast::Name::Ident(name.clone())).into_boxed(), + None, + ), + }) + .collect(), + from: compound.select.from.as_ref().map(|f| f.to_sql_ast()), + where_clause: Some(compound.select.where_clause.0.clone().into_boxed()), + group_by: None, + window_clause: Vec::new(), }, - nulls: None, }) - .collect() - }), - limit: self.limit.map(|l| { - Box::new(ast::Limit { - expr: ast::Expr::Literal(ast::Literal::Numeric(l.to_string())), - offset: None, + .collect(), + }, + order_by: self + .body + .select + .order_by + .as_ref() + .map(|o| { + o.columns + .iter() + .map(|(name, order)| ast::SortedColumn { + expr: ast::Expr::Id(ast::Name::Ident(name.clone())).into_boxed(), + order: match order { + SortOrder::Asc => Some(ast::SortOrder::Asc), + SortOrder::Desc => Some(ast::SortOrder::Desc), + }, + nulls: None, + }) + .collect() }) + .unwrap_or(Vec::new()), + limit: self.limit.map(|l| ast::Limit { + expr: ast::Expr::Literal(ast::Literal::Numeric(l.to_string())).into_boxed(), + offset: None, }), } } } + impl Display for Select { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.to_sql_ast().to_fmt_with_context(f, &EmptyContext {}) diff --git a/sql_generation/model/query/transaction.rs b/sql_generation/model/query/transaction.rs index a73fb076e..2280357fa 100644 --- a/sql_generation/model/query/transaction.rs +++ 
b/sql_generation/model/query/transaction.rs @@ -2,8 +2,6 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; - #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct Begin { pub(crate) immediate: bool, @@ -15,32 +13,6 @@ pub(crate) struct Commit; #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct Rollback; -impl Shadow for Begin { - type Result = Vec>; - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - tables.snapshot = Some(tables.tables.clone()); - vec![] - } -} - -impl Shadow for Commit { - type Result = Vec>; - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - tables.snapshot = None; - vec![] - } -} - -impl Shadow for Rollback { - type Result = Vec>; - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - if let Some(tables_) = tables.snapshot.take() { - tables.tables = tables_; - } - vec![] - } -} - impl Display for Begin { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "BEGIN {}", if self.immediate { "IMMEDIATE" } else { "" }) diff --git a/sql_generation/model/query/update.rs b/sql_generation/model/query/update.rs index a4cc13fa8..c7c3a5a58 100644 --- a/sql_generation/model/query/update.rs +++ b/sql_generation/model/query/update.rs @@ -2,7 +2,7 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; +use crate::model::table::SimValue; use super::predicate::Predicate; @@ -19,43 +19,6 @@ impl Update { } } -impl Shadow for Update { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - let table = tables.tables.iter_mut().find(|t| t.name == self.table); - - let table = if let Some(table) = table { - table - } else { - return Err(anyhow::anyhow!( - "Table {} does not exist. 
UPDATE statement ignored.", - self.table - )); - }; - - let t2 = table.clone(); - for row in table - .rows - .iter_mut() - .filter(|r| self.predicate.test(r, &t2)) - { - for (column, set_value) in &self.set_values { - if let Some((idx, _)) = table - .columns - .iter() - .enumerate() - .find(|(_, c)| &c.name == column) - { - row[idx] = set_value.clone(); - } - } - } - - Ok(vec![]) - } -} - impl Display for Update { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "UPDATE {} SET ", self.table)?; From 0c1228b48474a55455ab04dc864f47ab61ddfb5b Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Mon, 25 Aug 2025 20:11:01 -0300 Subject: [PATCH 61/73] add Generation context trait to decouple Simulator specific code --- parser/src/ast.rs | 10 +- sql_generation/generation/expr.rs | 43 +- sql_generation/generation/mod.rs | 59 +- sql_generation/generation/plan.rs | 833 ---------- sql_generation/generation/predicate/mod.rs | 4 +- sql_generation/generation/predicate/unary.rs | 1 + sql_generation/generation/property.rs | 1533 ------------------ sql_generation/generation/query.rs | 132 +- sql_generation/generation/table.rs | 6 +- sql_generation/model/mod.rs | 2 - sql_generation/model/query/create.rs | 4 +- sql_generation/model/query/create_index.rs | 24 +- sql_generation/model/query/delete.rs | 6 +- sql_generation/model/query/drop.rs | 4 +- sql_generation/model/query/insert.rs | 4 +- sql_generation/model/query/mod.rs | 20 +- sql_generation/model/query/predicate.rs | 23 +- sql_generation/model/query/select.rs | 67 +- sql_generation/model/query/transaction.rs | 8 +- sql_generation/model/query/update.rs | 8 +- sql_generation/model/table.rs | 41 +- 21 files changed, 223 insertions(+), 2609 deletions(-) delete mode 100644 sql_generation/generation/plan.rs delete mode 100644 sql_generation/generation/property.rs diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 44a427dcb..73adbf781 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -3,6 +3,8 @@ 
pub mod fmt; use strum_macros::{EnumIter, EnumString}; +use crate::ast::fmt::ToTokens; + /// `?` or `$` Prepared statement arg placeholder(s) #[derive(Default)] pub struct ParameterInfo { @@ -1188,7 +1190,7 @@ bitflags::bitflags! { } /// Sort orders -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum SortOrder { /// `ASC` @@ -1197,6 +1199,12 @@ pub enum SortOrder { Desc, } +impl core::fmt::Display for SortOrder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_fmt(f) + } +} + /// `NULLS FIRST` or `NULLS LAST` #[derive(Copy, Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] diff --git a/sql_generation/generation/expr.rs b/sql_generation/generation/expr.rs index c5e33758c..c07d81414 100644 --- a/sql_generation/generation/expr.rs +++ b/sql_generation/generation/expr.rs @@ -5,7 +5,7 @@ use turso_parser::ast::{ use crate::{ generation::{ frequency, gen_random_text, one_of, pick, pick_index, Arbitrary, ArbitraryFrom, - ArbitrarySizedFrom, + ArbitrarySizedFrom, GenerationContext, }, model::table::SimValue, }; @@ -58,8 +58,8 @@ where } // Freestyling generation -impl ArbitrarySizedFrom<&SimulatorEnv> for Expr { - fn arbitrary_sized_from(rng: &mut R, t: &SimulatorEnv, size: usize) -> Self { +impl ArbitrarySizedFrom<&C> for Expr { + fn arbitrary_sized_from(rng: &mut R, t: &C, size: usize) -> Self { frequency( vec![ ( @@ -199,28 +199,11 @@ impl Arbitrary for Type { } } -struct CollateName(String); - -impl Arbitrary for CollateName { - fn arbitrary(rng: &mut R) -> Self { - let choice = rng.random_range(0..3); - CollateName( - match choice { - 0 => "BINARY", - 1 => "RTRIM", - 2 => "NOCASE", - _ => unreachable!(), - } - .to_string(), - ) - } -} - -impl ArbitraryFrom<&SimulatorEnv> for QualifiedName { - fn arbitrary_from(rng: 
&mut R, t: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for QualifiedName { + fn arbitrary_from(rng: &mut R, t: &C) -> Self { // TODO: for now just generate table name - let table_idx = pick_index(t.tables.len(), rng); - let table = &t.tables[table_idx]; + let table_idx = pick_index(t.tables().len(), rng); + let table = &t.tables()[table_idx]; // TODO: for now forego alias Self { db_name: None, @@ -230,8 +213,8 @@ impl ArbitraryFrom<&SimulatorEnv> for QualifiedName { } } -impl ArbitraryFrom<&SimulatorEnv> for LikeOperator { - fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for LikeOperator { + fn arbitrary_from(rng: &mut R, _t: &C) -> Self { let choice = rng.random_range(0..4); match choice { 0 => LikeOperator::Glob, @@ -244,8 +227,8 @@ impl ArbitraryFrom<&SimulatorEnv> for LikeOperator { } // Current implementation does not take into account the columns affinity nor if table is Strict -impl ArbitraryFrom<&SimulatorEnv> for ast::Literal { - fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for ast::Literal { + fn arbitrary_from(rng: &mut R, _t: &C) -> Self { loop { let choice = rng.random_range(0..5); let lit = match choice { @@ -282,8 +265,8 @@ impl ArbitraryFrom<&Vec<&SimValue>> for ast::Expr { } } -impl ArbitraryFrom<&SimulatorEnv> for UnaryOperator { - fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for UnaryOperator { + fn arbitrary_from(rng: &mut R, _t: &C) -> Self { let choice = rng.random_range(0..4); match choice { 0 => Self::BitwiseNot, diff --git a/sql_generation/generation/mod.rs b/sql_generation/generation/mod.rs index 44ae7f34d..d9b7e0cbc 100644 --- a/sql_generation/generation/mod.rs +++ b/sql_generation/generation/mod.rs @@ -3,13 +3,24 @@ use std::{iter::Sum, ops::SubAssign}; use anarchist_readable_name_generator_lib::readable_name_custom; use rand::{distr::uniform::SampleUniform, Rng}; -mod expr; -pub mod plan; -mod predicate; -pub mod 
property; +use crate::model::table::Table; + +pub mod expr; +pub mod predicate; pub mod query; pub mod table; +pub struct Opts { + /// Indexes enabled + indexes: bool, +} + +/// Trait used to provide context to generation functions +pub trait GenerationContext { + fn tables(&self) -> &Vec
; + fn opts(&self) -> &Opts; +} + type ArbitraryFromFunc<'a, R, T> = Box T + 'a>; type Choice<'a, R, T> = (usize, Box Option + 'a>); @@ -66,11 +77,7 @@ pub trait ArbitraryFromMaybe { /// the operations we require for the implementation. // todo: switch to a simpler type signature that can accommodate all integer and float types, which // should be enough for our purposes. -pub(crate) fn frequency< - T, - R: Rng, - N: Sum + PartialOrd + Copy + Default + SampleUniform + SubAssign, ->( +pub fn frequency( choices: Vec<(N, ArbitraryFromFunc)>, rng: &mut R, ) -> T { @@ -88,7 +95,7 @@ pub(crate) fn frequency< } /// one_of is a helper function for composing different generators with equal probability of occurrence. -pub(crate) fn one_of(choices: Vec>, rng: &mut R) -> T { +pub fn one_of(choices: Vec>, rng: &mut R) -> T { let index = rng.random_range(0..choices.len()); choices[index](rng) } @@ -96,7 +103,7 @@ pub(crate) fn one_of(choices: Vec>, rng: &mut /// backtrack is a helper function for composing different "failable" generators. /// The function takes a list of functions that return an Option, along with number of retries /// to make before giving up. -pub(crate) fn backtrack(mut choices: Vec>, rng: &mut R) -> Option { +pub fn backtrack(mut choices: Vec>, rng: &mut R) -> Option { loop { // If there are no more choices left, we give up let choices_ = choices @@ -122,24 +129,20 @@ pub(crate) fn backtrack(mut choices: Vec>, rng: &mut R) } /// pick is a helper function for uniformly picking a random element from a slice -pub(crate) fn pick<'a, T, R: Rng>(choices: &'a [T], rng: &mut R) -> &'a T { +pub fn pick<'a, T, R: Rng>(choices: &'a [T], rng: &mut R) -> &'a T { let index = rng.random_range(0..choices.len()); &choices[index] } /// pick_index is typically used for picking an index from a slice to later refer to the element /// at that index. 
-pub(crate) fn pick_index(choices: usize, rng: &mut R) -> usize { +pub fn pick_index(choices: usize, rng: &mut R) -> usize { rng.random_range(0..choices) } /// pick_n_unique is a helper function for uniformly picking N unique elements from a range. /// The elements themselves are usize, typically representing indices. -pub(crate) fn pick_n_unique( - range: std::ops::Range, - n: usize, - rng: &mut R, -) -> Vec { +pub fn pick_n_unique(range: std::ops::Range, n: usize, rng: &mut R) -> Vec { use rand::seq::SliceRandom; let mut items: Vec = range.collect(); items.shuffle(rng); @@ -148,7 +151,7 @@ pub(crate) fn pick_n_unique( /// gen_random_text uses `anarchist_readable_name_generator_lib` to generate random /// readable names for tables, columns, text values etc. -pub(crate) fn gen_random_text(rng: &mut T) -> String { +pub fn gen_random_text(rng: &mut T) -> String { let big_text = rng.random_ratio(1, 1000); if big_text { // let max_size: u64 = 2 * 1024 * 1024 * 1024; @@ -164,3 +167,21 @@ pub(crate) fn gen_random_text(rng: &mut T) -> String { name.replace("-", "_") } } + +pub fn pick_unique( + items: &[T], + count: usize, + rng: &mut impl rand::Rng, +) -> Vec +where + ::Owned: PartialEq, +{ + let mut picked: Vec = Vec::new(); + while picked.len() < count { + let item = pick(items, rng); + if !picked.contains(&item.to_owned()) { + picked.push(item.to_owned()); + } + } + picked +} diff --git a/sql_generation/generation/plan.rs b/sql_generation/generation/plan.rs deleted file mode 100644 index eac9359b3..000000000 --- a/sql_generation/generation/plan.rs +++ /dev/null @@ -1,833 +0,0 @@ -use std::{ - collections::HashSet, - fmt::{Debug, Display}, - path::Path, - sync::Arc, - vec, -}; - -use serde::{Deserialize, Serialize}; - -use turso_core::{Connection, Result, StepResult}; - -use crate::{ - generation::query::SelectFree, - model::{ - query::{update::Update, Create, CreateIndex, Delete, Drop, Insert, Query, Select}, - table::SimValue, - }, - runner::{ - env::{SimConnection, 
SimulationType, SimulatorTables}, - io::SimulatorIO, - }, - SimulatorEnv, -}; - -use crate::generation::{frequency, Arbitrary, ArbitraryFrom}; - -use super::property::{remaining, Property}; - -pub(crate) type ResultSet = Result>>; - -#[derive(Clone, Serialize, Deserialize)] -pub(crate) struct InteractionPlan { - pub(crate) plan: Vec, -} - -impl InteractionPlan { - /// Compute via diff computes a a plan from a given `.plan` file without the need to parse - /// sql. This is possible because there are two versions of the plan file, one that is human - /// readable and one that is serialized as JSON. Under watch mode, the users will be able to - /// delete interactions from the human readable file, and this function uses the JSON file as - /// a baseline to detect with interactions were deleted and constructs the plan from the - /// remaining interactions. - pub(crate) fn compute_via_diff(plan_path: &Path) -> Vec> { - let interactions = std::fs::read_to_string(plan_path).unwrap(); - let interactions = interactions.lines().collect::>(); - - let plan: InteractionPlan = serde_json::from_str( - std::fs::read_to_string(plan_path.with_extension("json")) - .unwrap() - .as_str(), - ) - .unwrap(); - - let mut plan = plan - .plan - .into_iter() - .map(|i| i.interactions()) - .collect::>(); - - let (mut i, mut j) = (0, 0); - - while i < interactions.len() && j < plan.len() { - if interactions[i].starts_with("-- begin") - || interactions[i].starts_with("-- end") - || interactions[i].is_empty() - { - i += 1; - continue; - } - - // interactions[i] is the i'th line in the human readable plan - // plan[j][k] is the k'th interaction in the j'th property - let mut k = 0; - - while k < plan[j].len() { - if i >= interactions.len() { - let _ = plan.split_off(j + 1); - let _ = plan[j].split_off(k); - break; - } - tracing::error!("Comparing '{}' with '{}'", interactions[i], plan[j][k]); - if interactions[i].contains(plan[j][k].to_string().as_str()) { - i += 1; - k += 1; - } else { - 
plan[j].remove(k); - panic!("Comparing '{}' with '{}'", interactions[i], plan[j][k]); - } - } - - if plan[j].is_empty() { - plan.remove(j); - } else { - j += 1; - } - } - let _ = plan.split_off(j); - plan - } -} - -pub(crate) struct InteractionPlanState { - pub(crate) stack: Vec, - pub(crate) interaction_pointer: usize, - pub(crate) secondary_pointer: usize, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) enum Interactions { - Property(Property), - Query(Query), - Fault(Fault), -} - -impl Interactions { - pub(crate) fn name(&self) -> Option<&str> { - match self { - Interactions::Property(property) => Some(property.name()), - Interactions::Query(_) => None, - Interactions::Fault(_) => None, - } - } - - pub(crate) fn interactions(&self) -> Vec { - match self { - Interactions::Property(property) => property.interactions(), - Interactions::Query(query) => vec![Interaction::Query(query.clone())], - Interactions::Fault(fault) => vec![Interaction::Fault(fault.clone())], - } - } -} - -impl Interactions { - pub(crate) fn dependencies(&self) -> HashSet { - match self { - Interactions::Property(property) => { - property - .interactions() - .iter() - .fold(HashSet::new(), |mut acc, i| match i { - Interaction::Query(q) => { - acc.extend(q.dependencies()); - acc - } - _ => acc, - }) - } - Interactions::Query(query) => query.dependencies(), - Interactions::Fault(_) => HashSet::new(), - } - } - - pub(crate) fn uses(&self) -> Vec { - match self { - Interactions::Property(property) => { - property - .interactions() - .iter() - .fold(vec![], |mut acc, i| match i { - Interaction::Query(q) => { - acc.extend(q.uses()); - acc - } - _ => acc, - }) - } - Interactions::Query(query) => query.uses(), - Interactions::Fault(_) => vec![], - } - } -} - -impl Display for InteractionPlan { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for interactions in &self.plan { - match interactions { - Interactions::Property(property) => { - let name = 
property.name(); - writeln!(f, "-- begin testing '{name}'")?; - for interaction in property.interactions() { - write!(f, "\t")?; - - match interaction { - Interaction::Query(query) => writeln!(f, "{query};")?, - Interaction::Assumption(assumption) => { - writeln!(f, "-- ASSUME {};", assumption.name)? - } - Interaction::Assertion(assertion) => { - writeln!(f, "-- ASSERT {};", assertion.name)? - } - Interaction::Fault(fault) => writeln!(f, "-- FAULT '{fault}';")?, - Interaction::FsyncQuery(query) => { - writeln!(f, "-- FSYNC QUERY;")?; - writeln!(f, "{query};")?; - writeln!(f, "{query};")? - } - Interaction::FaultyQuery(query) => { - writeln!(f, "{query}; -- FAULTY QUERY")? - } - } - } - writeln!(f, "-- end testing '{name}'")?; - } - Interactions::Fault(fault) => { - writeln!(f, "-- FAULT '{fault}'")?; - } - Interactions::Query(query) => { - writeln!(f, "{query};")?; - } - } - } - - Ok(()) - } -} - -#[derive(Debug, Clone, Copy)] -pub(crate) struct InteractionStats { - pub(crate) read_count: usize, - pub(crate) write_count: usize, - pub(crate) delete_count: usize, - pub(crate) update_count: usize, - pub(crate) create_count: usize, - pub(crate) create_index_count: usize, - pub(crate) drop_count: usize, - pub(crate) begin_count: usize, - pub(crate) commit_count: usize, - pub(crate) rollback_count: usize, -} - -impl Display for InteractionStats { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Read: {}, Write: {}, Delete: {}, Update: {}, Create: {}, CreateIndex: {}, Drop: {}, Begin: {}, Commit: {}, Rollback: {}", - self.read_count, - self.write_count, - self.delete_count, - self.update_count, - self.create_count, - self.create_index_count, - self.drop_count, - self.begin_count, - self.commit_count, - self.rollback_count, - ) - } -} - -#[derive(Debug)] -pub(crate) enum Interaction { - Query(Query), - Assumption(Assertion), - Assertion(Assertion), - Fault(Fault), - /// Will attempt to run any random query. 
However, when the connection tries to sync it will - /// close all connections and reopen the database and assert that no data was lost - FsyncQuery(Query), - FaultyQuery(Query), -} - -impl Display for Interaction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Query(query) => write!(f, "{query}"), - Self::Assumption(assumption) => write!(f, "ASSUME {}", assumption.name), - Self::Assertion(assertion) => write!(f, "ASSERT {}", assertion.name), - Self::Fault(fault) => write!(f, "FAULT '{fault}'"), - Self::FsyncQuery(query) => write!(f, "{query}"), - Self::FaultyQuery(query) => write!(f, "{query}; -- FAULTY QUERY"), - } - } -} - -type AssertionFunc = dyn Fn(&Vec, &mut SimulatorEnv) -> Result>; - -enum AssertionAST { - Pick(), -} - -pub(crate) struct Assertion { - pub(crate) func: Box, - pub(crate) name: String, // For display purposes in the plan -} - -impl Debug for Assertion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Assertion") - .field("name", &self.name) - .finish() - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) enum Fault { - Disconnect, - ReopenDatabase, -} - -impl Display for Fault { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Fault::Disconnect => write!(f, "DISCONNECT"), - Fault::ReopenDatabase => write!(f, "REOPEN_DATABASE"), - } - } -} - -impl InteractionPlan { - pub(crate) fn new() -> Self { - Self { plan: Vec::new() } - } - - pub(crate) fn stats(&self) -> InteractionStats { - let mut stats = InteractionStats { - read_count: 0, - write_count: 0, - delete_count: 0, - update_count: 0, - create_count: 0, - create_index_count: 0, - drop_count: 0, - begin_count: 0, - commit_count: 0, - rollback_count: 0, - }; - - fn query_stat(q: &Query, stats: &mut InteractionStats) { - match q { - Query::Select(_) => stats.read_count += 1, - Query::Insert(_) => stats.write_count += 1, - Query::Delete(_) => 
stats.delete_count += 1, - Query::Create(_) => stats.create_count += 1, - Query::Drop(_) => stats.drop_count += 1, - Query::Update(_) => stats.update_count += 1, - Query::CreateIndex(_) => stats.create_index_count += 1, - Query::Begin(_) => stats.begin_count += 1, - Query::Commit(_) => stats.commit_count += 1, - Query::Rollback(_) => stats.rollback_count += 1, - } - } - for interactions in &self.plan { - match interactions { - Interactions::Property(property) => { - for interaction in &property.interactions() { - if let Interaction::Query(query) = interaction { - query_stat(query, &mut stats); - } - } - } - Interactions::Query(query) => { - query_stat(query, &mut stats); - } - Interactions::Fault(_) => {} - } - } - - stats - } -} - -impl ArbitraryFrom<&mut SimulatorEnv> for InteractionPlan { - fn arbitrary_from(rng: &mut R, env: &mut SimulatorEnv) -> Self { - let mut plan = InteractionPlan::new(); - - let num_interactions = env.opts.max_interactions; - - // First create at least one table - let create_query = Create::arbitrary(rng); - env.tables.push(create_query.table.clone()); - - plan.plan - .push(Interactions::Query(Query::Create(create_query))); - - while plan.plan.len() < num_interactions { - tracing::debug!( - "Generating interaction {}/{}", - plan.plan.len(), - num_interactions - ); - let interactions = Interactions::arbitrary_from(rng, (env, plan.stats())); - interactions.shadow(&mut env.tables); - plan.plan.push(interactions); - } - - tracing::info!("Generated plan with {} interactions", plan.plan.len()); - plan - } -} - -impl Interaction { - pub(crate) fn execute_query(&self, conn: &mut Arc, _io: &SimulatorIO) -> ResultSet { - if let Self::Query(query) = self { - let query_str = query.to_string(); - let rows = conn.query(&query_str); - if rows.is_err() { - let err = rows.err(); - tracing::debug!( - "Error running query '{}': {:?}", - &query_str[0..query_str.len().min(4096)], - err - ); - if let Some(turso_core::LimboError::ParseError(e)) = err { - 
panic!("Unexpected parse error: {e}"); - } - return Err(err.unwrap()); - } - let rows = rows?; - assert!(rows.is_some()); - let mut rows = rows.unwrap(); - let mut out = Vec::new(); - while let Ok(row) = rows.step() { - match row { - StepResult::Row => { - let row = rows.row().unwrap(); - let mut r = Vec::new(); - for v in row.get_values() { - let v = v.into(); - r.push(v); - } - out.push(r); - } - StepResult::IO => { - rows.run_once().unwrap(); - } - StepResult::Interrupt => {} - StepResult::Done => { - break; - } - StepResult::Busy => { - return Err(turso_core::LimboError::Busy); - } - } - } - - Ok(out) - } else { - unreachable!("unexpected: this function should only be called on queries") - } - } - - pub(crate) fn execute_assertion( - &self, - stack: &Vec, - env: &mut SimulatorEnv, - ) -> Result<()> { - match self { - Self::Assertion(assertion) => { - let result = assertion.func.as_ref()(stack, env); - match result { - Ok(Ok(())) => Ok(()), - Ok(Err(message)) => Err(turso_core::LimboError::InternalError(format!( - "Assertion '{}' failed: {}", - assertion.name, message - ))), - Err(err) => Err(turso_core::LimboError::InternalError(format!( - "Assertion '{}' execution error: {}", - assertion.name, err - ))), - } - } - _ => { - unreachable!("unexpected: this function should only be called on assertions") - } - } - } - - pub(crate) fn execute_assumption( - &self, - stack: &Vec, - env: &mut SimulatorEnv, - ) -> Result<()> { - match self { - Self::Assumption(assumption) => { - let result = assumption.func.as_ref()(stack, env); - match result { - Ok(Ok(())) => Ok(()), - Ok(Err(message)) => Err(turso_core::LimboError::InternalError(format!( - "Assumption '{}' failed: {}", - assumption.name, message - ))), - Err(err) => Err(turso_core::LimboError::InternalError(format!( - "Assumption '{}' execution error: {}", - assumption.name, err - ))), - } - } - _ => { - unreachable!("unexpected: this function should only be called on assumptions") - } - } - } - - pub(crate) fn 
execute_fault(&self, env: &mut SimulatorEnv, conn_index: usize) -> Result<()> { - match self { - Self::Fault(fault) => { - match fault { - Fault::Disconnect => { - if env.connections[conn_index].is_connected() { - env.connections[conn_index].disconnect(); - } else { - return Err(turso_core::LimboError::InternalError( - "connection already disconnected".into(), - )); - } - env.connections[conn_index] = SimConnection::Disconnected; - } - Fault::ReopenDatabase => { - reopen_database(env); - } - } - Ok(()) - } - _ => { - unreachable!("unexpected: this function should only be called on faults") - } - } - } - - pub(crate) fn execute_fsync_query( - &self, - conn: Arc, - env: &mut SimulatorEnv, - ) -> ResultSet { - if let Self::FsyncQuery(query) = self { - let query_str = query.to_string(); - let rows = conn.query(&query_str); - if rows.is_err() { - let err = rows.err(); - tracing::debug!( - "Error running query '{}': {:?}", - &query_str[0..query_str.len().min(4096)], - err - ); - return Err(err.unwrap()); - } - let mut rows = rows.unwrap().unwrap(); - let mut out = Vec::new(); - while let Ok(row) = rows.step() { - match row { - StepResult::Row => { - let row = rows.row().unwrap(); - let mut r = Vec::new(); - for v in row.get_values() { - let v = v.into(); - r.push(v); - } - out.push(r); - } - StepResult::IO => { - let syncing = { - let files = env.io.files.borrow(); - // TODO: currently assuming we only have 1 file that is syncing - files - .iter() - .any(|file| file.sync_completion.borrow().is_some()) - }; - if syncing { - reopen_database(env); - } else { - rows.run_once().unwrap(); - } - } - StepResult::Done => { - break; - } - StepResult::Busy => { - return Err(turso_core::LimboError::Busy); - } - StepResult::Interrupt => {} - } - } - - Ok(out) - } else { - unreachable!("unexpected: this function should only be called on queries") - } - } - - pub(crate) fn execute_faulty_query( - &self, - conn: &Arc, - env: &mut SimulatorEnv, - ) -> ResultSet { - use rand::Rng; - if 
let Self::FaultyQuery(query) = self { - let query_str = query.to_string(); - let rows = conn.query(&query_str); - if rows.is_err() { - let err = rows.err(); - tracing::debug!( - "Error running query '{}': {:?}", - &query_str[0..query_str.len().min(4096)], - err - ); - if let Some(turso_core::LimboError::ParseError(e)) = err { - panic!("Unexpected parse error: {e}"); - } - return Err(err.unwrap()); - } - let mut rows = rows.unwrap().unwrap(); - let mut out = Vec::new(); - let mut current_prob = 0.05; - let mut incr = 0.001; - loop { - let syncing = { - let files = env.io.files.borrow(); - files - .iter() - .any(|file| file.sync_completion.borrow().is_some()) - }; - let inject_fault = env.rng.gen_bool(current_prob); - // TODO: avoid for now injecting faults when syncing - if inject_fault && !syncing { - env.io.inject_fault(true); - } - - match rows.step()? { - StepResult::Row => { - let row = rows.row().unwrap(); - let mut r = Vec::new(); - for v in row.get_values() { - let v = v.into(); - r.push(v); - } - out.push(r); - } - StepResult::IO => { - rows.run_once()?; - current_prob += incr; - if current_prob > 1.0 { - current_prob = 1.0; - } else { - incr *= 1.01; - } - } - StepResult::Done => { - break; - } - StepResult::Busy => { - return Err(turso_core::LimboError::Busy); - } - StepResult::Interrupt => {} - } - } - - Ok(out) - } else { - unreachable!("unexpected: this function should only be called on queries") - } - } -} - -fn reopen_database(env: &mut SimulatorEnv) { - // 1. Close all connections without default checkpoint-on-close behavior - // to expose bugs related to how we handle WAL - let num_conns = env.connections.len(); - env.connections.clear(); - - // Clear all open files - // TODO: for correct reporting of faults we should get all the recorded numbers and transfer to the new file - env.io.files.borrow_mut().clear(); - - // 2. 
Re-open database - match env.type_ { - SimulationType::Differential => { - for _ in 0..num_conns { - env.connections.push(SimConnection::SQLiteConnection( - rusqlite::Connection::open(env.get_db_path()) - .expect("Failed to open SQLite connection"), - )); - } - } - SimulationType::Default | SimulationType::Doublecheck => { - env.db = None; - let db = match turso_core::Database::open_file( - env.io.clone(), - env.get_db_path().to_str().expect("path should be 'to_str'"), - false, - true, - ) { - Ok(db) => db, - Err(e) => { - tracing::error!( - "Failed to open database at {}: {}", - env.get_db_path().display(), - e - ); - panic!("Failed to open database: {e}"); - } - }; - - env.db = Some(db); - - for _ in 0..num_conns { - env.connections.push(SimConnection::LimboConnection( - env.db.as_ref().expect("db to be Some").connect().unwrap(), - )); - } - } - }; -} - -fn random_create(rng: &mut R, env: &SimulatorEnv) -> Interactions { - let mut create = Create::arbitrary(rng); - while env.tables.iter().any(|t| t.name == create.table.name) { - create = Create::arbitrary(rng); - } - Interactions::Query(Query::Create(create)) -} - -fn random_read(rng: &mut R, env: &SimulatorEnv) -> Interactions { - Interactions::Query(Query::Select(Select::arbitrary_from(rng, env))) -} - -fn random_expr(rng: &mut R, env: &SimulatorEnv) -> Interactions { - Interactions::Query(Query::Select(SelectFree::arbitrary_from(rng, env).0)) -} - -fn random_write(rng: &mut R, env: &SimulatorEnv) -> Interactions { - Interactions::Query(Query::Insert(Insert::arbitrary_from(rng, env))) -} - -fn random_delete(rng: &mut R, env: &SimulatorEnv) -> Interactions { - Interactions::Query(Query::Delete(Delete::arbitrary_from(rng, env))) -} - -fn random_update(rng: &mut R, env: &SimulatorEnv) -> Interactions { - Interactions::Query(Query::Update(Update::arbitrary_from(rng, env))) -} - -fn random_drop(rng: &mut R, env: &SimulatorEnv) -> Interactions { - Interactions::Query(Query::Drop(Drop::arbitrary_from(rng, env))) -} - 
-fn random_create_index(rng: &mut R, env: &SimulatorEnv) -> Option { - if env.tables.is_empty() { - return None; - } - let mut create_index = CreateIndex::arbitrary_from(rng, env); - while env - .tables - .iter() - .find(|t| t.name == create_index.table_name) - .expect("table should exist") - .indexes - .iter() - .any(|i| i == &create_index.index_name) - { - create_index = CreateIndex::arbitrary_from(rng, env); - } - - Some(Interactions::Query(Query::CreateIndex(create_index))) -} - -fn random_fault(rng: &mut R, env: &SimulatorEnv) -> Interactions { - let faults = if env.opts.disable_reopen_database { - vec![Fault::Disconnect] - } else { - vec![Fault::Disconnect, Fault::ReopenDatabase] - }; - let fault = faults[rng.random_range(0..faults.len())].clone(); - Interactions::Fault(fault) -} - -impl ArbitraryFrom<(&SimulatorEnv, InteractionStats)> for Interactions { - fn arbitrary_from( - rng: &mut R, - (env, stats): (&SimulatorEnv, InteractionStats), - ) -> Self { - let remaining_ = remaining(env, &stats); - frequency( - vec![ - ( - f64::min(remaining_.read, remaining_.write) + remaining_.create, - Box::new(|rng: &mut R| { - Interactions::Property(Property::arbitrary_from(rng, (env, &stats))) - }), - ), - ( - remaining_.read, - Box::new(|rng: &mut R| random_read(rng, env)), - ), - ( - remaining_.read / 3.0, - Box::new(|rng: &mut R| random_expr(rng, env)), - ), - ( - remaining_.write, - Box::new(|rng: &mut R| random_write(rng, env)), - ), - ( - remaining_.create, - Box::new(|rng: &mut R| random_create(rng, env)), - ), - ( - remaining_.create_index, - Box::new(|rng: &mut R| { - if let Some(interaction) = random_create_index(rng, env) { - interaction - } else { - // if no tables exist, we can't create an index, so fallback to creating a table - random_create(rng, env) - } - }), - ), - ( - remaining_.delete, - Box::new(|rng: &mut R| random_delete(rng, env)), - ), - ( - remaining_.update, - Box::new(|rng: &mut R| random_update(rng, env)), - ), - ( - // remaining_.drop, - 
0.0, - Box::new(|rng: &mut R| random_drop(rng, env)), - ), - ( - remaining_ - .read - .min(remaining_.write) - .min(remaining_.create) - .max(1.0), - Box::new(|rng: &mut R| random_fault(rng, env)), - ), - ], - rng, - ) - } -} diff --git a/sql_generation/generation/predicate/mod.rs b/sql_generation/generation/predicate/mod.rs index 0a06dead0..b919ad0bd 100644 --- a/sql_generation/generation/predicate/mod.rs +++ b/sql_generation/generation/predicate/mod.rs @@ -8,8 +8,8 @@ use crate::model::{ use super::{one_of, ArbitraryFrom}; -mod binary; -mod unary; +pub mod binary; +pub mod unary; #[derive(Debug)] struct CompoundPredicate(Predicate); diff --git a/sql_generation/generation/predicate/unary.rs b/sql_generation/generation/predicate/unary.rs index 6800740d7..62c6d7d65 100644 --- a/sql_generation/generation/predicate/unary.rs +++ b/sql_generation/generation/predicate/unary.rs @@ -64,6 +64,7 @@ impl ArbitraryFromMaybe<&Vec<&SimValue>> for FalseValue { } } +#[allow(dead_code)] pub struct BitNotValue(pub SimValue); impl ArbitraryFromMaybe<(&SimValue, bool)> for BitNotValue { diff --git a/sql_generation/generation/property.rs b/sql_generation/generation/property.rs deleted file mode 100644 index 495f75ec7..000000000 --- a/sql_generation/generation/property.rs +++ /dev/null @@ -1,1533 +0,0 @@ -use serde::{Deserialize, Serialize}; -use turso_core::{types, LimboError}; -use turso_parser::ast::{self}; - -use crate::{ - model::{ - query::{ - predicate::Predicate, - select::{ - CompoundOperator, CompoundSelect, Distinctness, ResultColumn, SelectBody, - SelectInner, - }, - transaction::{Begin, Commit, Rollback}, - update::Update, - Create, Delete, Drop, Insert, Query, Select, - }, - table::SimValue, - }, - runner::env::SimulatorEnv, -}; - -use super::{ - frequency, pick, pick_index, - plan::{Assertion, Interaction, InteractionStats, ResultSet}, - ArbitraryFrom, -}; - -/// Properties are representations of executable specifications -/// about the database behavior. 
-#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) enum Property { - /// Insert-Select is a property in which the inserted row - /// must be in the resulting rows of a select query that has a - /// where clause that matches the inserted row. - /// The execution of the property is as follows - /// INSERT INTO VALUES (...) - /// I_0 - /// I_1 - /// ... - /// I_n - /// SELECT * FROM WHERE - /// The interactions in the middle has the following constraints; - /// - There will be no errors in the middle interactions. - /// - The inserted row will not be deleted. - /// - The inserted row will not be updated. - /// - The table `t` will not be renamed, dropped, or altered. - InsertValuesSelect { - /// The insert query - insert: Insert, - /// Selected row index - row_index: usize, - /// Additional interactions in the middle of the property - queries: Vec, - /// The select query - select: Select, - /// Interactive query information if any - interactive: Option, - }, - /// ReadYourUpdatesBack is a property in which the updated rows - /// must be in the resulting rows of a select query that has a - /// where clause that matches the updated row. - /// The execution of the property is as follows - /// UPDATE SET WHERE - /// SELECT FROM WHERE - /// These interactions are executed in immediate succession - /// just to verify the property that our updates did what they - /// were supposed to do. - ReadYourUpdatesBack { - update: Update, - select: Select, - }, - /// TableHasExpectedContent is a property in which the table - /// must have the expected content, i.e. all the insertions and - /// updates and deletions should have been persisted in the way - /// we think they were. - /// The execution of the property is as follows - /// SELECT * FROM - /// ASSERT - TableHasExpectedContent { - table: String, - }, - /// Double Create Failure is a property in which creating - /// the same table twice leads to an error. 
- /// The execution of the property is as follows - /// CREATE TABLE (...) - /// I_0 - /// I_1 - /// ... - /// I_n - /// CREATE TABLE (...) -> Error - /// The interactions in the middle has the following constraints; - /// - There will be no errors in the middle interactions. - /// - Table `t` will not be renamed or dropped. - DoubleCreateFailure { - /// The create query - create: Create, - /// Additional interactions in the middle of the property - queries: Vec, - }, - /// Select Limit is a property in which the select query - /// has a limit clause that is respected by the query. - /// The execution of the property is as follows - /// SELECT * FROM WHERE LIMIT - /// This property is a single-interaction property. - /// The interaction has the following constraints; - /// - The select query will respect the limit clause. - SelectLimit { - /// The select query - select: Select, - }, - /// Delete-Select is a property in which the deleted row - /// must not be in the resulting rows of a select query that has a - /// where clause that matches the deleted row. In practice, `p1` of - /// the delete query will be used as the predicate for the select query, - /// hence the select should return NO ROWS. - /// The execution of the property is as follows - /// DELETE FROM WHERE - /// I_0 - /// I_1 - /// ... - /// I_n - /// SELECT * FROM WHERE - /// The interactions in the middle has the following constraints; - /// - There will be no errors in the middle interactions. - /// - A row that holds for the predicate will not be inserted. - /// - The table `t` will not be renamed, dropped, or altered. - DeleteSelect { - table: String, - predicate: Predicate, - queries: Vec, - }, - /// Drop-Select is a property in which selecting from a dropped table - /// should result in an error. - /// The execution of the property is as follows - /// DROP TABLE - /// I_0 - /// I_1 - /// ... 
- /// I_n - /// SELECT * FROM WHERE -> Error - /// The interactions in the middle has the following constraints; - /// - There will be no errors in the middle interactions. - /// - The table `t` will not be created, no table will be renamed to `t`. - DropSelect { - table: String, - queries: Vec, - select: Select, - }, - /// Select-Select-Optimizer is a property in which we test the optimizer by - /// running two equivalent select queries, one with `SELECT from ` - /// and the other with `SELECT * from WHERE `. As highlighted by - /// Rigger et al. in Non-Optimizing Reference Engine Construction(NoREC), SQLite - /// tends to optimize `where` statements while keeping the result column expressions - /// unoptimized. This property is used to test the optimizer. The property is successful - /// if the two queries return the same number of rows. - SelectSelectOptimizer { - table: String, - predicate: Predicate, - }, - /// Where-True-False-Null is a property that tests the boolean logic implementation - /// in the database. It relies on the fact that `P == true || P == false || P == null` should return true, - /// as SQLite uses a ternary logic system. This property is invented in "Finding Bugs in Database Systems via Query Partitioning" - /// by Rigger et al. and it is canonically called Ternary Logic Partitioning (TLP). - WhereTrueFalseNull { - select: Select, - predicate: Predicate, - }, - /// UNION-ALL-Preserves-Cardinality is a property that tests the UNION ALL operator - /// implementation in the database. It relies on the fact that `SELECT * FROM WHERE UNION ALL SELECT * FROM WHERE ` - /// should return the same number of rows as `SELECT FROM WHERE `. - /// > The property is succesfull when the UNION ALL of 2 select queries returns the same number of rows - /// > as the sum of the two select queries. 
- UNIONAllPreservesCardinality { - select: Select, - where_clause: Predicate, - }, - /// FsyncNoWait is a property which tests if we do not loose any data after not waiting for fsync. - /// - /// # Interactions - /// - Executes the `query` without waiting for fsync - /// - Drop all connections and Reopen the database - /// - Execute the `query` again - /// - Query tables to assert that the values were inserted - /// - FsyncNoWait { - query: Query, - tables: Vec, - }, - FaultyQuery { - query: Query, - tables: Vec, - }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct InteractiveQueryInfo { - start_with_immediate: bool, - end_with_commit: bool, -} - -impl Property { - pub(crate) fn name(&self) -> &str { - match self { - Property::InsertValuesSelect { .. } => "Insert-Values-Select", - Property::ReadYourUpdatesBack { .. } => "Read-Your-Updates-Back", - Property::TableHasExpectedContent { .. } => "Table-Has-Expected-Content", - Property::DoubleCreateFailure { .. } => "Double-Create-Failure", - Property::SelectLimit { .. } => "Select-Limit", - Property::DeleteSelect { .. } => "Delete-Select", - Property::DropSelect { .. } => "Drop-Select", - Property::SelectSelectOptimizer { .. } => "Select-Select-Optimizer", - Property::WhereTrueFalseNull { .. } => "Where-True-False-Null", - Property::FsyncNoWait { .. } => "FsyncNoWait", - Property::FaultyQuery { .. } => "FaultyQuery", - Property::UNIONAllPreservesCardinality { .. } => "UNION-All-Preserves-Cardinality", - } - } - /// interactions construct a list of interactions, which is an executable representation of the property. - /// the requirement of property -> vec conversion emerges from the need to serialize the property, - /// and `interaction` cannot be serialized directly. 
- pub(crate) fn interactions(&self) -> Vec { - match self { - Property::TableHasExpectedContent { table } => { - let table = table.to_string(); - let table_name = table.clone(); - let assumption = Interaction::Assumption(Assertion { - name: format!("table {} exists", table.clone()), - func: Box::new(move |_: &Vec, env: &mut SimulatorEnv| { - if env.tables.iter().any(|t| t.name == table_name) { - Ok(Ok(())) - } else { - Ok(Err(format!("table {table_name} does not exist"))) - } - }), - }); - - let select_interaction = Interaction::Query(Query::Select(Select::simple( - table.clone(), - Predicate::true_(), - ))); - - let assertion = Interaction::Assertion(Assertion { - name: format!("table {} should have the expected content", table.clone()), - func: Box::new(move |stack: &Vec, env| { - let rows = stack.last().unwrap(); - let Ok(rows) = rows else { - return Ok(Err(format!("expected rows but got error: {rows:?}"))); - }; - let sim_table = env - .tables - .iter() - .find(|t| t.name == table) - .expect("table should be in enviroment"); - if rows.len() != sim_table.rows.len() { - return Ok(Err(format!( - "expected {} rows but got {} for table {}", - sim_table.rows.len(), - rows.len(), - table.clone() - ))); - } - for expected_row in sim_table.rows.iter() { - if !rows.contains(expected_row) { - return Ok(Err(format!( - "expected row {:?} not found in table {}", - expected_row, - table.clone() - ))); - } - } - Ok(Ok(())) - }), - }); - - vec![assumption, select_interaction, assertion] - } - Property::ReadYourUpdatesBack { update, select } => { - let table = update.table().to_string(); - let assumption = Interaction::Assumption(Assertion { - name: format!("table {} exists", table.clone()), - func: Box::new(move |_: &Vec, env: &mut SimulatorEnv| { - if env.tables.iter().any(|t| t.name == table.clone()) { - Ok(Ok(())) - } else { - Ok(Err(format!("table {} does not exist", table.clone()))) - } - }), - }); - - let update_interaction = 
Interaction::Query(Query::Update(update.clone())); - let select_interaction = Interaction::Query(Query::Select(select.clone())); - - let update = update.clone(); - - let table = update.table().to_string(); - - let assertion = Interaction::Assertion(Assertion { - name: format!( - "updated rows should be found and have the updated values for table {}", - table.clone() - ), - func: Box::new(move |stack: &Vec, _| { - let rows = stack.last().unwrap(); - match rows { - Ok(rows) => { - for row in rows { - for (i, (col, val)) in update.set_values.iter().enumerate() { - if &row[i] != val { - return Ok(Err(format!("updated row {} has incorrect value for column {col}: expected {val}, got {}", i, row[i]))); - } - } - } - Ok(Ok(())) - } - Err(err) => Err(LimboError::InternalError(err.to_string())), - } - }), - }); - - vec![ - assumption, - update_interaction, - select_interaction, - assertion, - ] - } - Property::InsertValuesSelect { - insert, - row_index, - queries, - select, - interactive, - } => { - let (table, values) = if let Insert::Values { table, values } = insert { - (table, values) - } else { - unreachable!( - "insert query should be Insert::Values for Insert-Values-Select property" - ) - }; - // Check that the insert query has at least 1 value - assert!( - !values.is_empty(), - "insert query should have at least 1 value" - ); - - // Pick a random row within the insert values - let row = values[*row_index].clone(); - - // Assume that the table exists - let assumption = Interaction::Assumption(Assertion { - name: format!("table {} exists", insert.table()), - func: Box::new({ - let table_name = table.clone(); - move |_: &Vec, env: &mut SimulatorEnv| { - if env.tables.iter().any(|t| t.name == table_name) { - Ok(Ok(())) - } else { - Ok(Err(format!("table {table_name} does not exist"))) - } - } - }), - }); - - let assertion = Interaction::Assertion(Assertion { - name: format!( - "row [{:?}] should be found in table {}, interactive={} commit={}, rollback={}", - 
row.iter().map(|v| v.to_string()).collect::>(), - insert.table(), - interactive.is_some(), - interactive - .as_ref() - .map(|i| i.end_with_commit) - .unwrap_or(false), - interactive - .as_ref() - .map(|i| !i.end_with_commit) - .unwrap_or(false), - ), - func: Box::new(move |stack: &Vec, _| { - let rows = stack.last().unwrap(); - match rows { - Ok(rows) => { - let found = rows.iter().any(|r| r == &row); - if found { - Ok(Ok(())) - } else { - Ok(Err(format!("row [{:?}] not found in table", row.iter().map(|v| v.to_string()).collect::>()))) - } - } - Err(err) => Err(LimboError::InternalError(err.to_string())), - } - }), - }); - - let mut interactions = Vec::new(); - interactions.push(assumption); - interactions.push(Interaction::Query(Query::Insert(insert.clone()))); - interactions.extend(queries.clone().into_iter().map(Interaction::Query)); - interactions.push(Interaction::Query(Query::Select(select.clone()))); - interactions.push(assertion); - - interactions - } - Property::DoubleCreateFailure { create, queries } => { - let table_name = create.table.name.clone(); - - let assumption = Interaction::Assumption(Assertion { - name: "Double-Create-Failure should not be called on an existing table" - .to_string(), - func: Box::new(move |_: &Vec, env: &mut SimulatorEnv| { - if !env.tables.iter().any(|t| t.name == table_name) { - Ok(Ok(())) - } else { - Ok(Err(format!("table {table_name} already exists"))) - } - }), - }); - - let cq1 = Interaction::Query(Query::Create(create.clone())); - let cq2 = Interaction::Query(Query::Create(create.clone())); - - let table_name = create.table.name.clone(); - - let assertion = Interaction::Assertion(Assertion { - name: - "creating two tables with the name should result in a failure for the second query" - .to_string(), - func: Box::new(move |stack: &Vec, _| { - let last = stack.last().unwrap(); - match last { - Ok(success) => Ok(Err(format!("expected table creation to fail but it succeeded: {success:?}"))), - Err(e) => { - if 
e.to_string().to_lowercase().contains(&format!("table {table_name} already exists")) { - Ok(Ok(())) - } else { - Ok(Err(format!("expected table already exists error, got: {e}"))) - } - } - } - }), - }); - - let mut interactions = Vec::new(); - interactions.push(assumption); - interactions.push(cq1); - interactions.extend(queries.clone().into_iter().map(Interaction::Query)); - interactions.push(cq2); - interactions.push(assertion); - - interactions - } - Property::SelectLimit { select } => { - let assumption = Interaction::Assumption(Assertion { - name: format!( - "table ({}) exists", - select - .dependencies() - .into_iter() - .collect::>() - .join(", ") - ), - func: Box::new({ - let table_name = select.dependencies(); - move |_: &Vec, env: &mut SimulatorEnv| { - if table_name - .iter() - .all(|table| env.tables.iter().any(|t| t.name == *table)) - { - Ok(Ok(())) - } else { - let missing_tables = table_name - .iter() - .filter(|t| !env.tables.iter().any(|t2| t2.name == **t)) - .collect::>(); - Ok(Err(format!("missing tables: {missing_tables:?}"))) - } - } - }), - }); - - let limit = select - .limit - .expect("Property::SelectLimit without a LIMIT clause"); - - let assertion = Interaction::Assertion(Assertion { - name: "select query should respect the limit clause".to_string(), - func: Box::new(move |stack: &Vec, _| { - let last = stack.last().unwrap(); - match last { - Ok(rows) => { - if limit >= rows.len() { - Ok(Ok(())) - } else { - Ok(Err(format!( - "limit {} violated: got {} rows", - limit, - rows.len() - ))) - } - } - Err(_) => Ok(Ok(())), - } - }), - }); - - vec![ - assumption, - Interaction::Query(Query::Select(select.clone())), - assertion, - ] - } - Property::DeleteSelect { - table, - predicate, - queries, - } => { - let assumption = Interaction::Assumption(Assertion { - name: format!("table {table} exists"), - func: Box::new({ - let table = table.clone(); - move |_: &Vec, env: &mut SimulatorEnv| { - if env.tables.iter().any(|t| t.name == table) { - 
Ok(Ok(())) - } else { - { - let available_tables: Vec = - env.tables.iter().map(|t| t.name.clone()).collect(); - Ok(Err(format!( - "table \'{table}\' not found. Available tables: {available_tables:?}" - ))) - } - } - } - }), - }); - - let delete = Interaction::Query(Query::Delete(Delete { - table: table.clone(), - predicate: predicate.clone(), - })); - - let select = Interaction::Query(Query::Select(Select::simple( - table.clone(), - predicate.clone(), - ))); - - let assertion = Interaction::Assertion(Assertion { - name: format!("`{select}` should return no values for table `{table}`",), - func: Box::new(move |stack: &Vec, _| { - let rows = stack.last().unwrap(); - match rows { - Ok(rows) => { - if rows.is_empty() { - Ok(Ok(())) - } else { - Ok(Err(format!( - "expected no rows but got {} rows: {:?}", - rows.len(), - rows.iter() - .map(|r| print_row(r)) - .collect::>() - .join(", ") - ))) - } - } - Err(err) => Err(LimboError::InternalError(err.to_string())), - } - }), - }); - - let mut interactions = Vec::new(); - interactions.push(assumption); - interactions.push(delete); - interactions.extend(queries.clone().into_iter().map(Interaction::Query)); - interactions.push(select); - interactions.push(assertion); - - interactions - } - Property::DropSelect { - table, - queries, - select, - } => { - let assumption = Interaction::Assumption(Assertion { - name: format!("table {table} exists"), - func: Box::new({ - let table = table.clone(); - move |_, env: &mut SimulatorEnv| { - if env.tables.iter().any(|t| t.name == table) { - Ok(Ok(())) - } else { - { - let available_tables: Vec = - env.tables.iter().map(|t| t.name.clone()).collect(); - Ok(Err(format!( - "table \'{table}\' not found. 
Available tables: {available_tables:?}" - ))) - } - } - } - }), - }); - - let table_name = table.clone(); - - let assertion = Interaction::Assertion(Assertion { - name: format!("select query should result in an error for table '{table}'"), - func: Box::new(move |stack: &Vec, _| { - let last = stack.last().unwrap(); - match last { - Ok(success) => Ok(Err(format!( - "expected table creation to fail but it succeeded: {success:?}" - ))), - Err(e) => { - if e.to_string() - .contains(&format!("Table {table_name} does not exist")) - { - Ok(Ok(())) - } else { - Ok(Err(format!( - "expected table does not exist error, got: {e}" - ))) - } - } - } - }), - }); - - let drop = Interaction::Query(Query::Drop(Drop { - table: table.clone(), - })); - - let select = Interaction::Query(Query::Select(select.clone())); - - let mut interactions = Vec::new(); - - interactions.push(assumption); - interactions.push(drop); - interactions.extend(queries.clone().into_iter().map(Interaction::Query)); - interactions.push(select); - interactions.push(assertion); - - interactions - } - Property::SelectSelectOptimizer { table, predicate } => { - let assumption = Interaction::Assumption(Assertion { - name: format!("table {table} exists"), - func: Box::new({ - let table = table.clone(); - move |_: &Vec, env: &mut SimulatorEnv| { - if env.tables.iter().any(|t| t.name == table) { - Ok(Ok(())) - } else { - { - let available_tables: Vec = - env.tables.iter().map(|t| t.name.clone()).collect(); - Ok(Err(format!( - "table \'{table}\' not found. 
Available tables: {available_tables:?}" - ))) - } - } - } - }), - }); - - let select1 = Interaction::Query(Query::Select(Select::single( - table.clone(), - vec![ResultColumn::Expr(predicate.clone())], - Predicate::true_(), - None, - Distinctness::All, - ))); - - let select2_query = Query::Select(Select::simple(table.clone(), predicate.clone())); - - let select2 = Interaction::Query(select2_query); - - let assertion = Interaction::Assertion(Assertion { - name: "select queries should return the same amount of results".to_string(), - func: Box::new(move |stack: &Vec, _| { - let select_star = stack.last().unwrap(); - let select_predicate = stack.get(stack.len() - 2).unwrap(); - match (select_predicate, select_star) { - (Ok(rows1), Ok(rows2)) => { - // If rows1 results have more than 1 column, there is a problem - if rows1.iter().any(|vs| vs.len() > 1) { - return Err(LimboError::InternalError( - "Select query without the star should return only one column".to_string(), - )); - } - // Count the 1s in the select query without the star - let rows1_count = rows1 - .iter() - .filter(|vs| { - let v = vs.first().unwrap(); - v.as_bool() - }) - .count(); - tracing::debug!( - "select1 returned {} rows, select2 returned {} rows", - rows1_count, - rows2.len() - ); - if rows1_count == rows2.len() { - Ok(Ok(())) - } else { - Ok(Err(format!( - "row counts don't match: {} vs {}", - rows1_count, - rows2.len() - ))) - } - } - (Err(e1), Err(e2)) => { - tracing::debug!("Error in select1 AND select2: {}, {}", e1, e2); - Ok(Ok(())) - } - (Err(e), _) | (_, Err(e)) => { - tracing::error!("Error in select1 OR select2: {}", e); - Err(LimboError::InternalError(e.to_string())) - } - } - }), - }); - - vec![assumption, select1, select2, assertion] - } - Property::FsyncNoWait { query, tables } => { - let checks = assert_all_table_values(tables); - Vec::from_iter( - std::iter::once(Interaction::FsyncQuery(query.clone())).chain(checks), - ) - } - Property::FaultyQuery { query, tables } => { - let 
checks = assert_all_table_values(tables); - let query_clone = query.clone(); - let assert = Assertion { - // A fault may not occur as we first signal we want a fault injected, - // then when IO is called the fault triggers. It may happen that a fault is injected - // but no IO happens right after it - name: "fault occured".to_string(), - func: Box::new(move |stack, env: &mut SimulatorEnv| { - let last = stack.last().unwrap(); - match last { - Ok(_) => { - let _ = query_clone.shadow(&mut env.tables); - Ok(Ok(())) - } - Err(err) => { - // We cannot make any assumptions about the error content; all we are about is, if the statement errored, - // we don't shadow the results into the simulator env, i.e. we assume whatever the statement did was rolled back. - tracing::error!("Fault injection produced error: {err}"); - Ok(Ok(())) - } - } - }), - }; - let first = [ - Interaction::FaultyQuery(query.clone()), - Interaction::Assertion(assert), - ] - .into_iter(); - Vec::from_iter(first.chain(checks)) - } - Property::WhereTrueFalseNull { select, predicate } => { - let assumption = Interaction::Assumption(Assertion { - name: format!( - "tables ({}) exists", - select - .dependencies() - .into_iter() - .collect::>() - .join(", ") - ), - func: Box::new({ - let tables = select.dependencies(); - move |_: &Vec, env: &mut SimulatorEnv| { - if tables - .iter() - .all(|table| env.tables.iter().any(|t| t.name == *table)) - { - Ok(Ok(())) - } else { - let missing_tables = tables - .iter() - .filter(|t| !env.tables.iter().any(|t2| t2.name == **t)) - .collect::>(); - Ok(Err(format!("missing tables: {missing_tables:?}"))) - } - } - }), - }); - - let old_predicate = select.body.select.where_clause.clone(); - - let p_true = Predicate::and(vec![old_predicate.clone(), predicate.clone()]); - let p_false = Predicate::and(vec![ - old_predicate.clone(), - Predicate::not(predicate.clone()), - ]); - let p_null = Predicate::and(vec![ - old_predicate.clone(), - Predicate::is(predicate.clone(), 
Predicate::null()), - ]); - - let select_tlp = Select { - body: SelectBody { - select: Box::new(SelectInner { - distinctness: select.body.select.distinctness, - columns: select.body.select.columns.clone(), - from: select.body.select.from.clone(), - where_clause: p_true, - order_by: None, - }), - compounds: vec![ - CompoundSelect { - operator: CompoundOperator::UnionAll, - select: Box::new(SelectInner { - distinctness: select.body.select.distinctness, - columns: select.body.select.columns.clone(), - from: select.body.select.from.clone(), - where_clause: p_false, - order_by: None, - }), - }, - CompoundSelect { - operator: CompoundOperator::UnionAll, - select: Box::new(SelectInner { - distinctness: select.body.select.distinctness, - columns: select.body.select.columns.clone(), - from: select.body.select.from.clone(), - where_clause: p_null, - order_by: None, - }), - }, - ], - }, - limit: None, - }; - - let select = Interaction::Query(Query::Select(select.clone())); - let select_tlp = Interaction::Query(Query::Select(select_tlp)); - - // select and select_tlp should return the same rows - let assertion = Interaction::Assertion(Assertion { - name: "select and select_tlp should return the same rows".to_string(), - func: Box::new(move |stack: &Vec, _: &mut SimulatorEnv| { - if stack.len() < 2 { - return Err(LimboError::InternalError( - "Not enough result sets on the stack".to_string(), - )); - } - - let select_result_set = stack.get(stack.len() - 2).unwrap(); - let select_tlp_result_set = stack.last().unwrap(); - - match (select_result_set, select_tlp_result_set) { - (Ok(select_rows), Ok(select_tlp_rows)) => { - if select_rows.len() != select_tlp_rows.len() { - return Ok(Err(format!("row count mismatch: select returned {} rows, select_tlp returned {} rows", select_rows.len(), select_tlp_rows.len()))); - } - // Check if any row in select_rows is not in select_tlp_rows - for row in select_rows.iter() { - if !select_tlp_rows.iter().any(|r| r == row) { - tracing::debug!( - 
"select and select_tlp returned different rows, ({}) is in select but not in select_tlp", - row.iter().map(|v| v.to_string()).collect::>().join(", ") - ); - return Ok(Err(format!( - "row mismatch: row [{}] exists in select results but not in select_tlp results", - print_row(row) - ))); - } - } - // Check if any row in select_tlp_rows is not in select_rows - for row in select_tlp_rows.iter() { - if !select_rows.iter().any(|r| r == row) { - tracing::debug!( - "select and select_tlp returned different rows, ({}) is in select_tlp but not in select", - row.iter().map(|v| v.to_string()).collect::>().join(", ") - ); - - return Ok(Err(format!( - "row mismatch: row [{}] exists in select_tlp but not in select", - print_row(row) - ))); - } - } - // If we reach here, the rows are the same - tracing::trace!( - "select and select_tlp returned the same rows: {:?}", - select_rows - ); - - Ok(Ok(())) - } - (Err(e), _) | (_, Err(e)) => { - tracing::error!("Error in select or select_tlp: {}", e); - Err(LimboError::InternalError(e.to_string())) - } - } - }), - }); - - vec![assumption, select, select_tlp, assertion] - } - Property::UNIONAllPreservesCardinality { - select, - where_clause, - } => { - let s1 = select.clone(); - let mut s2 = select.clone(); - s2.body.select.where_clause = where_clause.clone(); - let s3 = Select::compound(s1.clone(), s2.clone(), CompoundOperator::UnionAll); - - vec![ - Interaction::Query(Query::Select(s1.clone())), - Interaction::Query(Query::Select(s2.clone())), - Interaction::Query(Query::Select(s3.clone())), - Interaction::Assertion(Assertion { - name: "UNION ALL should preserve cardinality".to_string(), - func: Box::new(move |stack: &Vec, _: &mut SimulatorEnv| { - if stack.len() < 3 { - return Err(LimboError::InternalError( - "Not enough result sets on the stack".to_string(), - )); - } - - let select1 = stack.get(stack.len() - 3).unwrap(); - let select2 = stack.get(stack.len() - 2).unwrap(); - let union_all = stack.last().unwrap(); - - match (select1, 
select2, union_all) { - (Ok(rows1), Ok(rows2), Ok(union_rows)) => { - let count1 = rows1.len(); - let count2 = rows2.len(); - let union_count = union_rows.len(); - if union_count == count1 + count2 { - Ok(Ok(())) - } else { - Ok(Err(format!("UNION ALL should preserve cardinality but it didn't: {count1} + {count2} != {union_count}"))) - } - } - (Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => { - tracing::error!("Error in select queries: {}", e); - Err(LimboError::InternalError(e.to_string())) - } - } - }), - }), - ] - } - } - } -} - -fn assert_all_table_values(tables: &[String]) -> impl Iterator + use<'_> { - let checks = tables.iter().flat_map(|table| { - let select = Interaction::Query(Query::Select(Select::simple( - table.clone(), - Predicate::true_(), - ))); - - let assertion = Interaction::Assertion(Assertion { - name: format!("table {table} should contain all of its expected values"), - func: Box::new({ - let table = table.clone(); - move |stack: &Vec, env: &mut SimulatorEnv| { - let table = env.tables.iter().find(|t| t.name == table).ok_or_else(|| { - LimboError::InternalError(format!( - "table {table} should exist in simulator env" - )) - })?; - let last = stack.last().unwrap(); - match last { - Ok(vals) => { - // Check if all values in the table are present in the result set - // Find a value in the table that is not in the result set - let model_contains_db = table.rows.iter().find(|v| { - !vals.iter().any(|r| { - &r == v - }) - }); - let db_contains_model = vals.iter().find(|v| { - !table.rows.iter().any(|r| &r == v) - }); - - if let Some(model_contains_db) = model_contains_db { - tracing::debug!( - "table {} does not contain the expected values, the simulator model has more rows than the database: {:?}", - table.name, - print_row(model_contains_db) - ); - Ok(Err(format!("table {} does not contain the expected values, the simulator model has more rows than the database: {:?}", table.name, print_row(model_contains_db)))) - } else if let 
Some(db_contains_model) = db_contains_model { - tracing::debug!( - "table {} does not contain the expected values, the database has more rows than the simulator model: {:?}", - table.name, - print_row(db_contains_model) - ); - Ok(Err(format!("table {} does not contain the expected values, the database has more rows than the simulator model: {:?}", table.name, print_row(db_contains_model)))) - } else { - Ok(Ok(())) - } - } - Err(err) => Err(LimboError::InternalError(format!("{err}"))), - } - } - }), - }); - [select, assertion].into_iter() - }); - checks -} - -#[derive(Debug)] -pub(crate) struct Remaining { - pub(crate) read: f64, - pub(crate) write: f64, - pub(crate) create: f64, - pub(crate) create_index: f64, - pub(crate) delete: f64, - pub(crate) update: f64, - pub(crate) drop: f64, -} - -pub(crate) fn remaining(env: &SimulatorEnv, stats: &InteractionStats) -> Remaining { - let remaining_read = ((env.opts.max_interactions as f64 * env.opts.read_percent / 100.0) - - (stats.read_count as f64)) - .max(0.0); - let remaining_write = ((env.opts.max_interactions as f64 * env.opts.write_percent / 100.0) - - (stats.write_count as f64)) - .max(0.0); - let remaining_create = ((env.opts.max_interactions as f64 * env.opts.create_percent / 100.0) - - (stats.create_count as f64)) - .max(0.0); - - let remaining_create_index = - ((env.opts.max_interactions as f64 * env.opts.create_index_percent / 100.0) - - (stats.create_index_count as f64)) - .max(0.0); - - let remaining_delete = ((env.opts.max_interactions as f64 * env.opts.delete_percent / 100.0) - - (stats.delete_count as f64)) - .max(0.0); - let remaining_update = ((env.opts.max_interactions as f64 * env.opts.update_percent / 100.0) - - (stats.update_count as f64)) - .max(0.0); - let remaining_drop = ((env.opts.max_interactions as f64 * env.opts.drop_percent / 100.0) - - (stats.drop_count as f64)) - .max(0.0); - - Remaining { - read: remaining_read, - write: remaining_write, - create: remaining_create, - create_index: 
remaining_create_index, - delete: remaining_delete, - drop: remaining_drop, - update: remaining_update, - } -} - -fn property_insert_values_select( - rng: &mut R, - env: &SimulatorEnv, - remaining: &Remaining, -) -> Property { - // Get a random table - let table = pick(&env.tables, rng); - // Generate rows to insert - let rows = (0..rng.random_range(1..=5)) - .map(|_| Vec::::arbitrary_from(rng, table)) - .collect::>(); - - // Pick a random row to select - let row_index = pick_index(rows.len(), rng); - let row = rows[row_index].clone(); - - // Insert the rows - let insert_query = Insert::Values { - table: table.name.clone(), - values: rows, - }; - - // Choose if we want queries to be executed in an interactive transaction - let interactive = if rng.random_bool(0.5) { - Some(InteractiveQueryInfo { - start_with_immediate: rng.random_bool(0.5), - end_with_commit: rng.random_bool(0.5), - }) - } else { - None - }; - // Create random queries respecting the constraints - let mut queries = Vec::new(); - // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) - // - [x] The inserted row will not be deleted. - // - [x] The inserted row will not be updated. - // - [ ] The table `t` will not be renamed, dropped, or altered. (todo: add this constraint once ALTER or DROP is implemented) - if let Some(ref interactive) = interactive { - queries.push(Query::Begin(Begin { - immediate: interactive.start_with_immediate, - })); - } - for _ in 0..rng.random_range(0..3) { - let query = Query::arbitrary_from(rng, (env, remaining)); - match &query { - Query::Delete(Delete { - table: t, - predicate, - }) => { - // The inserted row will not be deleted. - if t == &table.name && predicate.test(&row, table) { - continue; - } - } - Query::Create(Create { table: t }) => { - // There will be no errors in the middle interactions. 
- // - Creating the same table is an error - if t.name == table.name { - continue; - } - } - Query::Update(Update { - table: t, - set_values: _, - predicate, - }) => { - // The inserted row will not be updated. - if t == &table.name && predicate.test(&row, table) { - continue; - } - } - _ => (), - } - queries.push(query); - } - if let Some(ref interactive) = interactive { - queries.push(if interactive.end_with_commit { - Query::Commit(Commit) - } else { - Query::Rollback(Rollback) - }); - } - - // Select the row - let select_query = Select::simple( - table.name.clone(), - Predicate::arbitrary_from(rng, (table, &row)), - ); - - Property::InsertValuesSelect { - insert: insert_query, - row_index, - queries, - select: select_query, - interactive, - } -} - -fn property_read_your_updates_back(rng: &mut R, env: &SimulatorEnv) -> Property { - // e.g. UPDATE t SET a=1, b=2 WHERE c=1; - let update = Update::arbitrary_from(rng, env); - // e.g. SELECT a, b FROM t WHERE c=1; - let select = Select::single( - update.table().to_string(), - update - .set_values - .iter() - .map(|(col, _)| ResultColumn::Column(col.clone())) - .collect(), - update.predicate.clone(), - None, - Distinctness::All, - ); - - Property::ReadYourUpdatesBack { update, select } -} - -fn property_table_has_expected_content(rng: &mut R, env: &SimulatorEnv) -> Property { - // Get a random table - let table = pick(&env.tables, rng); - Property::TableHasExpectedContent { - table: table.name.clone(), - } -} - -fn property_select_limit(rng: &mut R, env: &SimulatorEnv) -> Property { - // Get a random table - let table = pick(&env.tables, rng); - // Select the table - let select = Select::single( - table.name.clone(), - vec![ResultColumn::Star], - Predicate::arbitrary_from(rng, table), - Some(rng.random_range(1..=5)), - Distinctness::All, - ); - Property::SelectLimit { select } -} - -fn property_double_create_failure( - rng: &mut R, - env: &SimulatorEnv, - remaining: &Remaining, -) -> Property { - // Get a random table 
- let table = pick(&env.tables, rng); - // Create the table - let create_query = Create { - table: table.clone(), - }; - - // Create random queries respecting the constraints - let mut queries = Vec::new(); - // The interactions in the middle has the following constraints; - // - [x] There will be no errors in the middle interactions.(best effort) - // - [ ] Table `t` will not be renamed or dropped.(todo: add this constraint once ALTER or DROP is implemented) - for _ in 0..rng.random_range(0..3) { - let query = Query::arbitrary_from(rng, (env, remaining)); - if let Query::Create(Create { table: t }) = &query { - // There will be no errors in the middle interactions. - // - Creating the same table is an error - if t.name == table.name { - continue; - } - } - queries.push(query); - } - - Property::DoubleCreateFailure { - create: create_query, - queries, - } -} - -fn property_delete_select( - rng: &mut R, - env: &SimulatorEnv, - remaining: &Remaining, -) -> Property { - // Get a random table - let table = pick(&env.tables, rng); - // Generate a random predicate - let predicate = Predicate::arbitrary_from(rng, table); - - // Create random queries respecting the constraints - let mut queries = Vec::new(); - // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) - // - [x] A row that holds for the predicate will not be inserted. - // - [ ] The table `t` will not be renamed, dropped, or altered. (todo: add this constraint once ALTER or DROP is implemented) - for _ in 0..rng.random_range(0..3) { - let query = Query::arbitrary_from(rng, (env, remaining)); - match &query { - Query::Insert(Insert::Values { table: t, values }) => { - // A row that holds for the predicate will not be inserted. 
- if t == &table.name && values.iter().any(|v| predicate.test(v, table)) { - continue; - } - } - Query::Insert(Insert::Select { - table: t, - select: _, - }) => { - // A row that holds for the predicate will not be inserted. - if t == &table.name { - continue; - } - } - Query::Update(Update { table: t, .. }) => { - // A row that holds for the predicate will not be updated. - if t == &table.name { - continue; - } - } - Query::Create(Create { table: t }) => { - // There will be no errors in the middle interactions. - // - Creating the same table is an error - if t.name == table.name { - continue; - } - } - _ => (), - } - queries.push(query); - } - - Property::DeleteSelect { - table: table.name.clone(), - predicate, - queries, - } -} - -fn property_drop_select( - rng: &mut R, - env: &SimulatorEnv, - remaining: &Remaining, -) -> Property { - // Get a random table - let table = pick(&env.tables, rng); - - // Create random queries respecting the constraints - let mut queries = Vec::new(); - // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) - // - [-] The table `t` will not be created, no table will be renamed to `t`. 
(todo: update this constraint once ALTER is implemented) - for _ in 0..rng.random_range(0..3) { - let query = Query::arbitrary_from(rng, (env, remaining)); - if let Query::Create(Create { table: t }) = &query { - // - The table `t` will not be created - if t.name == table.name { - continue; - } - } - queries.push(query); - } - - let select = Select::simple(table.name.clone(), Predicate::arbitrary_from(rng, table)); - - Property::DropSelect { - table: table.name.clone(), - queries, - select, - } -} - -fn property_select_select_optimizer(rng: &mut R, env: &SimulatorEnv) -> Property { - // Get a random table - let table = pick(&env.tables, rng); - // Generate a random predicate - let predicate = Predicate::arbitrary_from(rng, table); - // Transform into a Binary predicate to force values to be casted to a bool - let expr = ast::Expr::Binary( - Box::new(predicate.0), - ast::Operator::And, - Box::new(Predicate::true_().0), - ); - - Property::SelectSelectOptimizer { - table: table.name.clone(), - predicate: Predicate(expr), - } -} - -fn property_where_true_false_null(rng: &mut R, env: &SimulatorEnv) -> Property { - // Get a random table - let table = pick(&env.tables, rng); - // Generate a random predicate - let p1 = Predicate::arbitrary_from(rng, table); - let p2 = Predicate::arbitrary_from(rng, table); - - // Create the select query - let select = Select::simple(table.name.clone(), p1); - - Property::WhereTrueFalseNull { - select, - predicate: p2, - } -} - -fn property_union_all_preserves_cardinality( - rng: &mut R, - env: &SimulatorEnv, -) -> Property { - // Get a random table - let table = pick(&env.tables, rng); - // Generate a random predicate - let p1 = Predicate::arbitrary_from(rng, table); - let p2 = Predicate::arbitrary_from(rng, table); - - // Create the select query - let select = Select::single( - table.name.clone(), - vec![ResultColumn::Star], - p1, - None, - Distinctness::All, - ); - - Property::UNIONAllPreservesCardinality { - select, - where_clause: p2, 
- } -} - -fn property_fsync_no_wait( - rng: &mut R, - env: &SimulatorEnv, - remaining: &Remaining, -) -> Property { - Property::FsyncNoWait { - query: Query::arbitrary_from(rng, (env, remaining)), - tables: env.tables.iter().map(|t| t.name.clone()).collect(), - } -} - -fn property_faulty_query( - rng: &mut R, - env: &SimulatorEnv, - remaining: &Remaining, -) -> Property { - Property::FaultyQuery { - query: Query::arbitrary_from(rng, (env, remaining)), - tables: env.tables.iter().map(|t| t.name.clone()).collect(), - } -} - -impl ArbitraryFrom<(&SimulatorEnv, &InteractionStats)> for Property { - fn arbitrary_from( - rng: &mut R, - (env, stats): (&SimulatorEnv, &InteractionStats), - ) -> Self { - let remaining_ = remaining(env, stats); - - frequency( - vec![ - ( - if !env.opts.disable_insert_values_select { - f64::min(remaining_.read, remaining_.write) - } else { - 0.0 - }, - Box::new(|rng: &mut R| property_insert_values_select(rng, env, &remaining_)), - ), - ( - remaining_.read, - Box::new(|rng: &mut R| property_table_has_expected_content(rng, env)), - ), - ( - f64::min(remaining_.read, remaining_.write), - Box::new(|rng: &mut R| property_read_your_updates_back(rng, env)), - ), - ( - if !env.opts.disable_double_create_failure { - remaining_.create / 2.0 - } else { - 0.0 - }, - Box::new(|rng: &mut R| property_double_create_failure(rng, env, &remaining_)), - ), - ( - if !env.opts.disable_select_limit { - remaining_.read - } else { - 0.0 - }, - Box::new(|rng: &mut R| property_select_limit(rng, env)), - ), - ( - if !env.opts.disable_delete_select { - f64::min(remaining_.read, remaining_.write).min(remaining_.delete) - } else { - 0.0 - }, - Box::new(|rng: &mut R| property_delete_select(rng, env, &remaining_)), - ), - ( - if !env.opts.disable_drop_select { - // remaining_.drop - 0.0 - } else { - 0.0 - }, - Box::new(|rng: &mut R| property_drop_select(rng, env, &remaining_)), - ), - ( - if !env.opts.disable_select_optimizer { - remaining_.read / 2.0 - } else { - 0.0 - }, - 
Box::new(|rng: &mut R| property_select_select_optimizer(rng, env)), - ), - ( - if env.opts.experimental_indexes && !env.opts.disable_where_true_false_null { - remaining_.read / 2.0 - } else { - 0.0 - }, - Box::new(|rng: &mut R| property_where_true_false_null(rng, env)), - ), - ( - if env.opts.experimental_indexes - && !env.opts.disable_union_all_preserves_cardinality - { - remaining_.read / 3.0 - } else { - 0.0 - }, - Box::new(|rng: &mut R| property_union_all_preserves_cardinality(rng, env)), - ), - ( - if !env.opts.disable_fsync_no_wait { - 50.0 // Freestyle number - } else { - 0.0 - }, - Box::new(|rng: &mut R| property_fsync_no_wait(rng, env, &remaining_)), - ), - ( - if !env.opts.disable_faulty_query { - 20.0 - } else { - 0.0 - }, - Box::new(|rng: &mut R| property_faulty_query(rng, env, &remaining_)), - ), - ], - rng, - ) - } -} - -fn print_row(row: &[SimValue]) -> String { - row.iter() - .map(|v| match &v.0 { - types::Value::Null => "NULL".to_string(), - types::Value::Integer(i) => i.to_string(), - types::Value::Float(f) => f.to_string(), - types::Value::Text(t) => t.to_string(), - types::Value::Blob(b) => format!( - "X'{}'", - b.iter() - .fold(String::new(), |acc, b| acc + &format!("{b:02X}")) - ), - }) - .collect::>() - .join(", ") -} diff --git a/sql_generation/generation/query.rs b/sql_generation/generation/query.rs index eff24613c..d7840a001 100644 --- a/sql_generation/generation/query.rs +++ b/sql_generation/generation/query.rs @@ -1,5 +1,6 @@ use crate::generation::{ - gen_random_text, pick_n_unique, Arbitrary, ArbitraryFrom, ArbitrarySizedFrom, + gen_random_text, pick_n_unique, pick_unique, Arbitrary, ArbitraryFrom, ArbitrarySizedFrom, + GenerationContext, }; use crate::model::query::predicate::Predicate; use crate::model::query::select::{ @@ -7,15 +8,13 @@ use crate::model::query::select::{ SelectInner, }; use crate::model::query::update::Update; -use crate::model::query::{Create, CreateIndex, Delete, Drop, Insert, Query, Select}; +use 
crate::model::query::{Create, CreateIndex, Delete, Drop, Insert, Select}; use crate::model::table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext}; -use crate::SimulatorEnv; use itertools::Itertools; use rand::Rng; use turso_parser::ast::{Expr, SortOrder}; -use super::property::Remaining; -use super::{backtrack, frequency, pick}; +use super::{backtrack, pick}; impl Arbitrary for Create { fn arbitrary(rng: &mut R) -> Self { @@ -87,14 +86,11 @@ impl ArbitraryFrom<&Vec
> for FromClause { } } -impl ArbitraryFrom<&SimulatorEnv> for SelectInner { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let from = FromClause::arbitrary_from(rng, &env.tables); - let mut tables = env.tables.clone(); - // todo: this is a temporary hack because env is not separated from the tables - let join_table = from - .shadow(&mut tables) - .expect("Failed to shadow FromClause"); +impl ArbitraryFrom<&C> for SelectInner { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let from = FromClause::arbitrary_from(rng, env.tables()); + let tables = env.tables().clone(); + let join_table = from.into_join_table(&tables); let cuml_col_count = join_table.columns().count(); let order_by = 'order_by: { @@ -137,7 +133,7 @@ impl ArbitraryFrom<&SimulatorEnv> for SelectInner { }; SelectInner { - distinctness: if env.opts.experimental_indexes { + distinctness: if env.opts().indexes { Distinctness::arbitrary(rng) } else { Distinctness::All @@ -150,12 +146,8 @@ impl ArbitraryFrom<&SimulatorEnv> for SelectInner { } } -impl ArbitrarySizedFrom<&SimulatorEnv> for SelectInner { - fn arbitrary_sized_from( - rng: &mut R, - env: &SimulatorEnv, - num_result_columns: usize, - ) -> Self { +impl ArbitrarySizedFrom<&C> for SelectInner { + fn arbitrary_sized_from(rng: &mut R, env: &C, num_result_columns: usize) -> Self { let mut select_inner = SelectInner::arbitrary_from(rng, env); let select_from = &select_inner.from.as_ref().unwrap(); let table_names = select_from @@ -168,7 +160,7 @@ impl ArbitrarySizedFrom<&SimulatorEnv> for SelectInner { let flat_columns_names = table_names .iter() .flat_map(|t| { - env.tables + env.tables() .iter() .find(|table| table.name == *t) .unwrap() @@ -208,22 +200,22 @@ impl Arbitrary for CompoundOperator { /// SelectFree is a wrapper around Select that allows for arbitrary generation /// of selects without requiring a specific environment, which is useful for generating /// arbitrary expressions without referring to the tables. 
-pub(crate) struct SelectFree(pub(crate) Select); +pub struct SelectFree(pub Select); -impl ArbitraryFrom<&SimulatorEnv> for SelectFree { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for SelectFree { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { let expr = Predicate(Expr::arbitrary_sized_from(rng, env, 8)); let select = Select::expr(expr); Self(select) } } -impl ArbitraryFrom<&SimulatorEnv> for Select { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for Select { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { // Generate a number of selects based on the query size // If experimental indexes are enabled, we can have selects with compounds // Otherwise, we just have a single select with no compounds - let num_compound_selects = if env.opts.experimental_indexes { + let num_compound_selects = if env.opts().indexes { match rng.random_range(0..=100) { 0..=95 => 0, 96..=99 => 1, @@ -235,7 +227,7 @@ impl ArbitraryFrom<&SimulatorEnv> for Select { }; let min_column_count_across_tables = - env.tables.iter().map(|t| t.columns.len()).min().unwrap(); + env.tables().iter().map(|t| t.columns.len()).min().unwrap(); let num_result_columns = rng.random_range(1..=min_column_count_across_tables); @@ -269,10 +261,10 @@ impl ArbitraryFrom<&SimulatorEnv> for Select { } } -impl ArbitraryFrom<&SimulatorEnv> for Insert { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for Insert { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { let gen_values = |rng: &mut R| { - let table = pick(&env.tables, rng); + let table = pick(env.tables(), rng); let num_rows = rng.random_range(1..10); let values: Vec> = (0..num_rows) .map(|_| { @@ -291,12 +283,12 @@ impl ArbitraryFrom<&SimulatorEnv> for Insert { let _gen_select = |rng: &mut R| { // Find a non-empty table - let select_table = env.tables.iter().find(|t| !t.rows.is_empty())?; + let select_table = 
env.tables().iter().find(|t| !t.rows.is_empty())?; let row = pick(&select_table.rows, rng); let predicate = Predicate::arbitrary_from(rng, (select_table, row)); // Pick another table to insert into let select = Select::simple(select_table.name.clone(), predicate); - let table = pick(&env.tables, rng); + let table = pick(env.tables(), rng); Some(Insert::Select { table: table.name.clone(), select: Box::new(select), @@ -309,9 +301,9 @@ impl ArbitraryFrom<&SimulatorEnv> for Insert { } } -impl ArbitraryFrom<&SimulatorEnv> for Delete { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let table = pick(&env.tables, rng); +impl ArbitraryFrom<&C> for Delete { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let table = pick(env.tables(), rng); Self { table: table.name.clone(), predicate: Predicate::arbitrary_from(rng, table), @@ -319,23 +311,23 @@ impl ArbitraryFrom<&SimulatorEnv> for Delete { } } -impl ArbitraryFrom<&SimulatorEnv> for Drop { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let table = pick(&env.tables, rng); +impl ArbitraryFrom<&C> for Drop { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let table = pick(env.tables(), rng); Self { table: table.name.clone(), } } } -impl ArbitraryFrom<&SimulatorEnv> for CreateIndex { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { +impl ArbitraryFrom<&C> for CreateIndex { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { assert!( - !env.tables.is_empty(), + !env.tables().is_empty(), "Cannot create an index when no tables exist in the environment." 
); - let table = pick(&env.tables, rng); + let table = pick(env.tables(), rng); if table.columns.is_empty() { panic!( @@ -376,57 +368,9 @@ impl ArbitraryFrom<&SimulatorEnv> for CreateIndex { } } -impl ArbitraryFrom<(&SimulatorEnv, &Remaining)> for Query { - fn arbitrary_from(rng: &mut R, (env, remaining): (&SimulatorEnv, &Remaining)) -> Self { - frequency( - vec![ - ( - remaining.create, - Box::new(|rng| Self::Create(Create::arbitrary(rng))), - ), - ( - remaining.read, - Box::new(|rng| Self::Select(Select::arbitrary_from(rng, env))), - ), - ( - remaining.write, - Box::new(|rng| Self::Insert(Insert::arbitrary_from(rng, env))), - ), - ( - remaining.update, - Box::new(|rng| Self::Update(Update::arbitrary_from(rng, env))), - ), - ( - f64::min(remaining.write, remaining.delete), - Box::new(|rng| Self::Delete(Delete::arbitrary_from(rng, env))), - ), - ], - rng, - ) - } -} - -fn pick_unique( - items: &[T], - count: usize, - rng: &mut impl rand::Rng, -) -> Vec -where - ::Owned: PartialEq, -{ - let mut picked: Vec = Vec::new(); - while picked.len() < count { - let item = pick(items, rng); - if !picked.contains(&item.to_owned()) { - picked.push(item.to_owned()); - } - } - picked -} - -impl ArbitraryFrom<&SimulatorEnv> for Update { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let table = pick(&env.tables, rng); +impl ArbitraryFrom<&C> for Update { + fn arbitrary_from(rng: &mut R, env: &C) -> Self { + let table = pick(env.tables(), rng); let num_cols = rng.random_range(1..=table.columns.len()); let columns = pick_unique(&table.columns, num_cols, rng); let set_values: Vec<(String, SimValue)> = columns diff --git a/sql_generation/generation/table.rs b/sql_generation/generation/table.rs index fdddb6ff2..d21397cbe 100644 --- a/sql_generation/generation/table.rs +++ b/sql_generation/generation/table.rs @@ -99,7 +99,7 @@ impl ArbitraryFrom<&ColumnType> for SimValue { } } -pub(crate) struct LTValue(pub(crate) SimValue); +pub struct LTValue(pub SimValue); impl 
ArbitraryFrom<&Vec<&SimValue>> for LTValue { fn arbitrary_from(rng: &mut R, values: &Vec<&SimValue>) -> Self { @@ -161,7 +161,7 @@ impl ArbitraryFrom<&SimValue> for LTValue { } } -pub(crate) struct GTValue(pub(crate) SimValue); +pub struct GTValue(pub SimValue); impl ArbitraryFrom<&Vec<&SimValue>> for GTValue { fn arbitrary_from(rng: &mut R, values: &Vec<&SimValue>) -> Self { @@ -223,7 +223,7 @@ impl ArbitraryFrom<&SimValue> for GTValue { } } -pub(crate) struct LikeValue(pub(crate) SimValue); +pub struct LikeValue(pub SimValue); impl ArbitraryFromMaybe<&SimValue> for LikeValue { fn arbitrary_from_maybe(rng: &mut R, value: &SimValue) -> Option { diff --git a/sql_generation/model/mod.rs b/sql_generation/model/mod.rs index e68355ee4..a29f56382 100644 --- a/sql_generation/model/mod.rs +++ b/sql_generation/model/mod.rs @@ -1,4 +1,2 @@ pub mod query; pub mod table; - -pub(crate) const FAULT_ERROR_MSG: &str = "Injected fault"; diff --git a/sql_generation/model/query/create.rs b/sql_generation/model/query/create.rs index e628b5cc5..607d5fe8d 100644 --- a/sql_generation/model/query/create.rs +++ b/sql_generation/model/query/create.rs @@ -5,8 +5,8 @@ use serde::{Deserialize, Serialize}; use crate::model::table::Table; #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Create { - pub(crate) table: Table, +pub struct Create { + pub table: Table, } impl Display for Create { diff --git a/sql_generation/model/query/create_index.rs b/sql_generation/model/query/create_index.rs index aba0f98bf..db9d15a04 100644 --- a/sql_generation/model/query/create_index.rs +++ b/sql_generation/model/query/create_index.rs @@ -1,25 +1,11 @@ use serde::{Deserialize, Serialize}; +use turso_parser::ast::SortOrder; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub enum SortOrder { - Asc, - Desc, -} - -impl std::fmt::Display for SortOrder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - SortOrder::Asc => write!(f, "ASC"), - 
SortOrder::Desc => write!(f, "DESC"), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub(crate) struct CreateIndex { - pub(crate) index_name: String, - pub(crate) table_name: String, - pub(crate) columns: Vec<(String, SortOrder)>, +pub struct CreateIndex { + pub index_name: String, + pub table_name: String, + pub columns: Vec<(String, SortOrder)>, } impl std::fmt::Display for CreateIndex { diff --git a/sql_generation/model/query/delete.rs b/sql_generation/model/query/delete.rs index a86479850..89ebd61b8 100644 --- a/sql_generation/model/query/delete.rs +++ b/sql_generation/model/query/delete.rs @@ -5,9 +5,9 @@ use serde::{Deserialize, Serialize}; use super::predicate::Predicate; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) struct Delete { - pub(crate) table: String, - pub(crate) predicate: Predicate, +pub struct Delete { + pub table: String, + pub predicate: Predicate, } impl Display for Delete { diff --git a/sql_generation/model/query/drop.rs b/sql_generation/model/query/drop.rs index d9a34a9e9..0d0ef31bb 100644 --- a/sql_generation/model/query/drop.rs +++ b/sql_generation/model/query/drop.rs @@ -3,8 +3,8 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) struct Drop { - pub(crate) table: String, +pub struct Drop { + pub table: String, } impl Display for Drop { diff --git a/sql_generation/model/query/insert.rs b/sql_generation/model/query/insert.rs index 9fd391612..d69921388 100644 --- a/sql_generation/model/query/insert.rs +++ b/sql_generation/model/query/insert.rs @@ -7,7 +7,7 @@ use crate::model::table::SimValue; use super::select::Select; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) enum Insert { +pub enum Insert { Values { table: String, values: Vec>, @@ -19,7 +19,7 @@ pub(crate) enum Insert { } impl Insert { - pub(crate) fn table(&self) -> &str { + pub fn table(&self) -> &str { match self 
{ Insert::Values { table, .. } | Insert::Select { table, .. } => table, } diff --git a/sql_generation/model/query/mod.rs b/sql_generation/model/query/mod.rs index 9ae222a9a..016ebec29 100644 --- a/sql_generation/model/query/mod.rs +++ b/sql_generation/model/query/mod.rs @@ -1,11 +1,11 @@ use std::{collections::HashSet, fmt::Display}; -pub(crate) use create::Create; -pub(crate) use create_index::CreateIndex; -pub(crate) use delete::Delete; -pub(crate) use drop::Drop; -pub(crate) use insert::Insert; -pub(crate) use select::Select; +pub use create::Create; +pub use create_index::CreateIndex; +pub use delete::Delete; +pub use drop::Drop; +pub use insert::Insert; +pub use select::Select; use serde::{Deserialize, Serialize}; use turso_parser::ast::fmt::ToSqlContext; use update::Update; @@ -24,7 +24,7 @@ pub mod update; // This type represents the potential queries on the database. #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) enum Query { +pub enum Query { Create(Create), Select(Select), Insert(Insert), @@ -38,7 +38,7 @@ pub(crate) enum Query { } impl Query { - pub(crate) fn dependencies(&self) -> HashSet { + pub fn dependencies(&self) -> HashSet { match self { Query::Select(select) => select.dependencies(), Query::Create(_) => HashSet::new(), @@ -53,7 +53,7 @@ impl Query { Query::Begin(_) | Query::Commit(_) | Query::Rollback(_) => HashSet::new(), } } - pub(crate) fn uses(&self) -> Vec { + pub fn uses(&self) -> Vec { match self { Query::Create(Create { table }) => vec![table.name.clone()], Query::Select(select) => select.dependencies().into_iter().collect(), @@ -86,7 +86,7 @@ impl Display for Query { } /// Used to print sql strings that already have all the context it needs -pub(crate) struct EmptyContext; +pub struct EmptyContext; impl ToSqlContext for EmptyContext { fn get_column_name( diff --git a/sql_generation/model/query/predicate.rs b/sql_generation/model/query/predicate.rs index bb4bf0bf7..30b671d72 100644 --- 
a/sql_generation/model/query/predicate.rs +++ b/sql_generation/model/query/predicate.rs @@ -9,27 +9,28 @@ use crate::model::table::{SimValue, Table, TableContext}; pub struct Predicate(pub ast::Expr); impl Predicate { - pub(crate) fn true_() -> Self { + pub fn true_() -> Self { Self(ast::Expr::Literal(ast::Literal::Keyword( "TRUE".to_string(), ))) } - pub(crate) fn false_() -> Self { + pub fn false_() -> Self { Self(ast::Expr::Literal(ast::Literal::Keyword( "FALSE".to_string(), ))) } - pub(crate) fn null() -> Self { + pub fn null() -> Self { Self(ast::Expr::Literal(ast::Literal::Null)) } - pub(crate) fn not(predicate: Predicate) -> Self { + #[allow(clippy::should_implement_trait)] + pub fn not(predicate: Predicate) -> Self { let expr = ast::Expr::Unary(ast::UnaryOperator::Not, Box::new(predicate.0)); Self(expr).parens() } - pub(crate) fn and(predicates: Vec) -> Self { + pub fn and(predicates: Vec) -> Self { if predicates.is_empty() { Self::true_() } else if predicates.len() == 1 { @@ -44,7 +45,7 @@ impl Predicate { } } - pub(crate) fn or(predicates: Vec) -> Self { + pub fn or(predicates: Vec) -> Self { if predicates.is_empty() { Self::false_() } else if predicates.len() == 1 { @@ -59,26 +60,26 @@ impl Predicate { } } - pub(crate) fn eq(lhs: Predicate, rhs: Predicate) -> Self { + pub fn eq(lhs: Predicate, rhs: Predicate) -> Self { let expr = ast::Expr::Binary(Box::new(lhs.0), ast::Operator::Equals, Box::new(rhs.0)); Self(expr).parens() } - pub(crate) fn is(lhs: Predicate, rhs: Predicate) -> Self { + pub fn is(lhs: Predicate, rhs: Predicate) -> Self { let expr = ast::Expr::Binary(Box::new(lhs.0), ast::Operator::Is, Box::new(rhs.0)); Self(expr).parens() } - pub(crate) fn parens(self) -> Self { + pub fn parens(self) -> Self { let expr = ast::Expr::Parenthesized(vec![self.0]); Self(expr) } - pub(crate) fn eval(&self, row: &[SimValue], table: &Table) -> Option { + pub fn eval(&self, row: &[SimValue], table: &Table) -> Option { expr_to_value(&self.0, row, table) } - 
pub(crate) fn test(&self, row: &[SimValue], table: &T) -> bool { + pub fn test(&self, row: &[SimValue], table: &T) -> bool { let value = expr_to_value(&self.0, row, table); value.is_some_and(|value| value.as_bool()) } diff --git a/sql_generation/model/query/select.rs b/sql_generation/model/query/select.rs index 6c7897e1d..6c34888ff 100644 --- a/sql_generation/model/query/select.rs +++ b/sql_generation/model/query/select.rs @@ -1,12 +1,13 @@ use std::{collections::HashSet, fmt::Display}; pub use ast::Distinctness; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use turso_parser::ast::{self, fmt::ToTokens, SortOrder}; use crate::model::{ query::EmptyContext, - table::{JoinType, JoinedTable}, + table::{JoinTable, JoinType, JoinedTable, Table}, }; use super::predicate::Predicate; @@ -14,6 +15,7 @@ use super::predicate::Predicate; /// `SELECT` or `RETURNING` result column // https://sqlite.org/syntax/result-column.html #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[allow(clippy::large_enum_variant)] pub enum ResultColumn { /// expression Expr(Predicate), @@ -33,9 +35,9 @@ impl Display for ResultColumn { } } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) struct Select { - pub(crate) body: SelectBody, - pub(crate) limit: Option, +pub struct Select { + pub body: SelectBody, + pub limit: Option, } impl Select { @@ -102,7 +104,7 @@ impl Select { } } - pub(crate) fn dependencies(&self) -> HashSet { + pub fn dependencies(&self) -> HashSet { if self.body.select.from.is_none() { return HashSet::new(); } @@ -209,13 +211,64 @@ impl FromClause { } } - pub(crate) fn dependencies(&self) -> Vec { + pub fn dependencies(&self) -> Vec { let mut deps = vec![self.table.clone()]; for join in &self.joins { deps.push(join.table.clone()); } deps } + + pub fn into_join_table(&self, tables: &[Table]) -> JoinTable { + let first_table = tables + .iter() + .find(|t| t.name == self.table) + .expect("Table not found"); + + let mut join_table = 
JoinTable { + tables: vec![first_table.clone()], + rows: Vec::new(), + }; + + for join in &self.joins { + let joined_table = tables + .iter() + .find(|t| t.name == join.table) + .expect("Joined table not found"); + + join_table.tables.push(joined_table.clone()); + + match join.join_type { + JoinType::Inner => { + // Implement inner join logic + let join_rows = joined_table + .rows + .iter() + .filter(|row| join.on.test(row, joined_table)) + .cloned() + .collect::>(); + // take a cartesian product of the rows + let all_row_pairs = join_table + .rows + .clone() + .into_iter() + .cartesian_product(join_rows.iter()); + + for (row1, row2) in all_row_pairs { + let row = row1.iter().chain(row2.iter()).cloned().collect::>(); + + let is_in = join.on.test(&row, &join_table); + + if is_in { + join_table.rows.push(row); + } + } + } + _ => todo!(), + } + } + join_table + } } impl Select { @@ -302,7 +355,7 @@ impl Select { }) .collect() }) - .unwrap_or(Vec::new()), + .unwrap_or_default(), limit: self.limit.map(|l| ast::Limit { expr: ast::Expr::Literal(ast::Literal::Numeric(l.to_string())).into_boxed(), offset: None, diff --git a/sql_generation/model/query/transaction.rs b/sql_generation/model/query/transaction.rs index 2280357fa..1114200a0 100644 --- a/sql_generation/model/query/transaction.rs +++ b/sql_generation/model/query/transaction.rs @@ -3,15 +3,15 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Begin { - pub(crate) immediate: bool, +pub struct Begin { + pub immediate: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Commit; +pub struct Commit; #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Rollback; +pub struct Rollback; impl Display for Begin { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/sql_generation/model/query/update.rs b/sql_generation/model/query/update.rs index c7c3a5a58..412731bbe 100644 --- 
a/sql_generation/model/query/update.rs +++ b/sql_generation/model/query/update.rs @@ -7,10 +7,10 @@ use crate::model::table::SimValue; use super::predicate::Predicate; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) struct Update { - pub(crate) table: String, - pub(crate) set_values: Vec<(String, SimValue)>, // Pair of value for set expressions => SET name=value - pub(crate) predicate: Predicate, +pub struct Update { + pub table: String, + pub set_values: Vec<(String, SimValue)>, // Pair of value for set expressions => SET name=value + pub predicate: Predicate, } impl Update { diff --git a/sql_generation/model/table.rs b/sql_generation/model/table.rs index 210039e17..87057b42b 100644 --- a/sql_generation/model/table.rs +++ b/sql_generation/model/table.rs @@ -6,7 +6,7 @@ use turso_parser::ast; use crate::model::query::predicate::Predicate; -pub(crate) struct Name(pub(crate) String); +pub struct Name(pub String); impl Deref for Name { type Target = str; @@ -41,11 +41,11 @@ impl TableContext for Table { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Table { - pub(crate) name: String, - pub(crate) columns: Vec, - pub(crate) rows: Vec>, - pub(crate) indexes: Vec, +pub struct Table { + pub name: String, + pub columns: Vec, + pub rows: Vec>, + pub indexes: Vec, } impl Table { @@ -60,11 +60,11 @@ impl Table { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Column { - pub(crate) name: String, - pub(crate) column_type: ColumnType, - pub(crate) primary: bool, - pub(crate) unique: bool, +pub struct Column { + pub name: String, + pub column_type: ColumnType, + pub primary: bool, + pub unique: bool, } // Uniquely defined by name in this case @@ -83,7 +83,7 @@ impl PartialEq for Column { impl Eq for Column {} #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) enum ColumnType { +pub enum ColumnType { Integer, Float, Text, @@ -136,23 +136,8 @@ pub struct JoinTable { pub rows: Vec>, } -fn 
float_to_string(float: &f64, serializer: S) -> Result -where - S: serde::Serializer, -{ - serializer.serialize_str(&format!("{float}")) -} - -fn string_to_float<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - s.parse().map_err(serde::de::Error::custom) -} - #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Serialize, Deserialize)] -pub(crate) struct SimValue(pub turso_core::Value); +pub struct SimValue(pub turso_core::Value); fn to_sqlite_blob(bytes: &[u8]) -> String { format!( From d3240844ec843499b510947022964f8e76cc9e86 Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Mon, 25 Aug 2025 15:53:45 -0300 Subject: [PATCH 62/73] refactor Core to remove the double indirection --- core/incremental/expr_compiler.rs | 4 +- core/translate/display.rs | 17 +++----- core/translate/expr.rs | 10 ++--- core/translate/group_by.rs | 6 +-- core/translate/insert.rs | 9 ++-- .../optimizer/lift_common_subexpressions.rs | 41 ++++++++----------- core/translate/optimizer/mod.rs | 2 +- core/translate/optimizer/order.rs | 4 +- core/translate/order_by.rs | 6 +-- core/translate/plan.rs | 8 ++-- core/translate/planner.rs | 8 ++-- core/translate/select.rs | 19 +++------ core/translate/update.rs | 4 +- core/util.rs | 20 ++++----- 14 files changed, 67 insertions(+), 91 deletions(-) diff --git a/core/incremental/expr_compiler.rs b/core/incremental/expr_compiler.rs index f94d72a2a..dae0687a2 100644 --- a/core/incremental/expr_compiler.rs +++ b/core/incremental/expr_compiler.rs @@ -67,7 +67,7 @@ fn transform_expr_for_dbsp(expr: &Expr, input_column_names: &[String]) -> Expr { distinctness: *distinctness, args: args .iter() - .map(|arg| Box::new(transform_expr_for_dbsp(arg, input_column_names))) + .map(|arg| transform_expr_for_dbsp(arg, input_column_names)) .collect(), order_by: order_by.clone(), filter_over: filter_over.clone(), @@ -75,7 +75,7 @@ fn transform_expr_for_dbsp(expr: &Expr, input_column_names: &[String]) -> 
Expr { Expr::Parenthesized(exprs) => Expr::Parenthesized( exprs .iter() - .map(|e| Box::new(transform_expr_for_dbsp(e, input_column_names))) + .map(|e| transform_expr_for_dbsp(e, input_column_names)) .collect(), ), // For other expression types, keep as is diff --git a/core/translate/display.rs b/core/translate/display.rs index 631e5a295..e69c843dd 100644 --- a/core/translate/display.rs +++ b/core/translate/display.rs @@ -368,13 +368,8 @@ impl ToTokens for SelectPlan { context: &C, ) -> Result<(), S::Error> { if !self.values.is_empty() { - ast::OneSelect::Values( - self.values - .iter() - .map(|values| values.iter().map(|v| Box::from(v.clone())).collect()) - .collect(), - ) - .to_tokens_with_context(s, context)?; + ast::OneSelect::Values(self.values.iter().map(|values| values.to_vec()).collect()) + .to_tokens_with_context(s, context)?; } else { s.append(TokenType::TK_SELECT, None)?; if self.distinctness.is_distinct() { @@ -448,7 +443,7 @@ impl ToTokens for SelectPlan { s.comma( self.order_by.iter().map(|(expr, order)| ast::SortedColumn { - expr: expr.clone(), + expr: expr.clone().into_boxed(), order: Some(*order), nulls: None, }), @@ -510,7 +505,7 @@ impl ToTokens for DeletePlan { s.comma( self.order_by.iter().map(|(expr, order)| ast::SortedColumn { - expr: expr.clone(), + expr: expr.clone().into_boxed(), order: Some(*order), nulls: None, }), @@ -563,7 +558,7 @@ impl ToTokens for UpdatePlan { ast::Set { col_names: vec![ast::Name::new(col_name)], - expr: set_expr.clone(), + expr: set_expr.clone().into_boxed(), } }), context, @@ -591,7 +586,7 @@ impl ToTokens for UpdatePlan { s.comma( self.order_by.iter().map(|(expr, order)| ast::SortedColumn { - expr: expr.clone(), + expr: expr.clone().into_boxed(), order: Some(*order), nulls: None, }), diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 398b31c3c..0e6873561 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -154,7 +154,7 @@ fn translate_in_list( program: &mut ProgramBuilder, 
referenced_tables: Option<&TableReferences>, lhs: &ast::Expr, - rhs: &[Box], + rhs: &[ast::Expr], not: bool, condition_metadata: ConditionMetadata, resolver: &Resolver, @@ -1633,9 +1633,7 @@ pub fn translate_expr( ); } - if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = - args[1].as_ref() - { + if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = args[1] { if let Ok(probability) = value.parse::() { if !(0.0..=1.0).contains(&probability) { crate::bail_parse_error!( @@ -2545,7 +2543,7 @@ fn translate_like_base( /// Returns the target register for the function. fn translate_function( program: &mut ProgramBuilder, - args: &[Box], + args: &[ast::Expr], referenced_tables: Option<&TableReferences>, resolver: &Resolver, target_register: usize, @@ -2671,7 +2669,7 @@ pub fn unwrap_parens_owned(expr: ast::Expr) -> Result<(ast::Expr, usize)> { ast::Expr::Parenthesized(mut exprs) => match exprs.len() { 1 => { paren_count += 1; - let (expr, count) = unwrap_parens_owned(*exprs.pop().unwrap().clone())?; + let (expr, count) = unwrap_parens_owned(exprs.pop().unwrap().clone())?; paren_count += count; Ok((expr, paren_count)) } diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 0a524f348..4795f3fee 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -85,7 +85,7 @@ pub fn init_group_by<'a>( group_by: &'a GroupBy, plan: &SelectPlan, result_columns: &'a [ResultSetColumn], - order_by: &'a [(Box, ast::SortOrder)], + order_by: &'a [(ast::Expr, ast::SortOrder)], ) -> Result<()> { collect_non_aggregate_expressions( &mut t_ctx.non_aggregate_expressions, @@ -238,13 +238,13 @@ fn collect_non_aggregate_expressions<'a>( group_by: &'a GroupBy, plan: &SelectPlan, root_result_columns: &'a [ResultSetColumn], - order_by: &'a [(Box, ast::SortOrder)], + order_by: &'a [(ast::Expr, ast::SortOrder)], ) -> Result<()> { let mut result_columns = Vec::new(); for expr in root_result_columns .iter() .map(|col| &col.expr) - 
.chain(order_by.iter().map(|(e, _)| e.as_ref())) + .chain(order_by.iter().map(|(e, _)| e)) .chain(group_by.having.iter().flatten()) { collect_result_columns(expr, plan, &mut result_columns)?; diff --git a/core/translate/insert.rs b/core/translate/insert.rs index 3e12204f4..b7560ceb5 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -100,7 +100,7 @@ pub fn translate_insert( let root_page = btree_table.root_page; - let mut values: Option>> = None; + let mut values: Option> = None; let inserting_multiple_rows = match &mut body { InsertBody::Select(select, _) => match &mut select.body.select { // TODO see how to avoid clone @@ -110,11 +110,10 @@ pub fn translate_insert( } let mut param_idx = 1; for expr in values_expr.iter_mut().flat_map(|v| v.iter_mut()) { - match expr.as_mut() { + match expr { Expr::Id(name) => { if name.is_double_quoted() { - *expr = - Expr::Literal(ast::Literal::String(format!("{name}"))).into(); + *expr = Expr::Literal(ast::Literal::String(name.to_string())); } else { // an INSERT INTO ... VALUES (...) cannot reference columns crate::bail_parse_error!("no such column: {name}"); @@ -838,7 +837,7 @@ fn translate_rows_multiple<'short, 'long: 'short>( #[allow(clippy::too_many_arguments)] fn translate_rows_single( program: &mut ProgramBuilder, - value: &[Box], + value: &[Expr], insertion: &Insertion, resolver: &Resolver, ) -> Result<()> { diff --git a/core/translate/optimizer/lift_common_subexpressions.rs b/core/translate/optimizer/lift_common_subexpressions.rs index a66a8ab1e..6da7c534a 100644 --- a/core/translate/optimizer/lift_common_subexpressions.rs +++ b/core/translate/optimizer/lift_common_subexpressions.rs @@ -104,7 +104,7 @@ pub(crate) fn lift_common_subexpressions_from_binary_or_terms( // If we unwrapped parentheses before, let's add them back. 
let mut top_level_expr = rebuild_and_expr_from_list(conjunct_list_for_or_branch); while num_unwrapped_parens > 0 { - top_level_expr = Expr::Parenthesized(vec![top_level_expr.into()]); + top_level_expr = Expr::Parenthesized(vec![top_level_expr]); num_unwrapped_parens -= 1; } new_or_operands_for_original_term.push(top_level_expr); @@ -246,13 +246,11 @@ mod tests { let or_expr = Expr::Binary( Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), x_expr.clone(), b_expr.clone()], - ) - .into()])), + )])), Operator::Or, Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), y_expr.clone(), b_expr.clone()], - ) - .into()])), + )])), ); let mut where_clause = vec![WhereTerm { @@ -275,9 +273,9 @@ mod tests { assert_eq!( nonconsumed_terms[0].expr, Expr::Binary( - Box::new(ast::Expr::Parenthesized(vec![x_expr.clone().into()])), + Box::new(ast::Expr::Parenthesized(vec![x_expr.clone()])), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![y_expr.clone().into()])) + Box::new(ast::Expr::Parenthesized(vec![y_expr.clone()])) ) ); assert_eq!(nonconsumed_terms[1].expr, a_expr); @@ -342,19 +340,16 @@ mod tests { Box::new(Expr::Binary( Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), x_expr.clone()], - ) - .into()])), + )])), Operator::Or, Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), y_expr.clone()], - ) - .into()])), + )])), )), Operator::Or, Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), z_expr.clone()], - ) - .into()])), + )])), ); let mut where_clause = vec![WhereTerm { @@ -377,12 +372,12 @@ mod tests { nonconsumed_terms[0].expr, Expr::Binary( Box::new(Expr::Binary( - Box::new(ast::Expr::Parenthesized(vec![x_expr.into()])), + Box::new(ast::Expr::Parenthesized(vec![x_expr])), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![y_expr.into()])), + 
Box::new(ast::Expr::Parenthesized(vec![y_expr])), )), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![z_expr.into()])), + Box::new(ast::Expr::Parenthesized(vec![z_expr])), ) ); assert_eq!(nonconsumed_terms[1].expr, a_expr); @@ -419,9 +414,9 @@ mod tests { ); let or_expr = Expr::Binary( - Box::new(ast::Expr::Parenthesized(vec![x_expr.into()])), + Box::new(ast::Expr::Parenthesized(vec![x_expr])), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![y_expr.into()])), + Box::new(ast::Expr::Parenthesized(vec![y_expr])), ); let mut where_clause = vec![WhereTerm { @@ -484,13 +479,11 @@ mod tests { let or_expr = Expr::Binary( Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), x_expr.clone()], - ) - .into()])), + )])), Operator::Or, Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), y_expr.clone()], - ) - .into()])), + )])), ); let mut where_clause = vec![WhereTerm { @@ -510,9 +503,9 @@ mod tests { assert_eq!( nonconsumed_terms[0].expr, Expr::Binary( - Box::new(ast::Expr::Parenthesized(vec![x_expr.into()])), + Box::new(ast::Expr::Parenthesized(vec![x_expr])), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![y_expr.into()])) + Box::new(ast::Expr::Parenthesized(vec![y_expr])) ) ); assert_eq!( diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index 8502ca005..005483dec 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -186,7 +186,7 @@ fn optimize_table_access( table_references: &mut TableReferences, available_indexes: &HashMap>>, where_clause: &mut [WhereTerm], - order_by: &mut Vec<(Box, SortOrder)>, + order_by: &mut Vec<(ast::Expr, SortOrder)>, group_by: &mut Option, ) -> Result>> { let access_methods_arena = RefCell::new(Vec::new()); diff --git a/core/translate/optimizer/order.rs b/core/translate/optimizer/order.rs index b7b3c4edc..739fc12b2 100644 --- a/core/translate/optimizer/order.rs +++ 
b/core/translate/optimizer/order.rs @@ -71,7 +71,7 @@ impl OrderTarget { /// TODO: this does not currently handle the case where we definitely cannot eliminate /// the ORDER BY sorter, but we could still eliminate the GROUP BY sorter. pub fn compute_order_target( - order_by: &mut Vec<(Box, SortOrder)>, + order_by: &mut Vec<(ast::Expr, SortOrder)>, group_by_opt: Option<&mut GroupBy>, ) -> Option { match (order_by.is_empty(), group_by_opt) { @@ -79,7 +79,7 @@ pub fn compute_order_target( (true, None) => None, // Only ORDER BY - we would like the joined result rows to be in the order specified by the ORDER BY (false, None) => OrderTarget::maybe_from_iterator( - order_by.iter().map(|(expr, order)| (expr.as_ref(), *order)), + order_by.iter().map(|(expr, order)| (expr, *order)), EliminatesSortBy::Order, ), // Only GROUP BY - we would like the joined result rows to be in the order specified by the GROUP BY diff --git a/core/translate/order_by.rs b/core/translate/order_by.rs index 4993e8010..c825592f3 100644 --- a/core/translate/order_by.rs +++ b/core/translate/order_by.rs @@ -36,7 +36,7 @@ pub fn init_order_by( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, result_columns: &[ResultSetColumn], - order_by: &[(Box, SortOrder)], + order_by: &[(ast::Expr, SortOrder)], referenced_tables: &TableReferences, ) -> Result<()> { let sort_cursor = program.alloc_cursor_id(CursorType::Sorter); @@ -55,7 +55,7 @@ pub fn init_order_by( */ let collations = order_by .iter() - .map(|(expr, _)| match expr.as_ref() { + .map(|(expr, _)| match expr { ast::Expr::Collate(_, collation_name) => { CollationSeq::new(collation_name.as_str()).map(Some) } @@ -324,7 +324,7 @@ pub struct OrderByRemapping { /// /// If any result columns can be skipped, this returns list of 2-tuples of (SkippedResultColumnIndex: usize, ResultColumnIndexInOrderBySorter: usize) pub fn order_by_deduplicate_result_columns( - order_by: &[(Box, SortOrder)], + order_by: &[(ast::Expr, SortOrder)], result_columns: 
&[ResultSetColumn], ) -> Vec { let mut result_column_remapping: Vec = Vec::new(); diff --git a/core/translate/plan.rs b/core/translate/plan.rs index eba50ce89..9dcbb669b 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -288,7 +288,7 @@ pub struct SelectPlan { /// group by clause pub group_by: Option, /// order by clause - pub order_by: Vec<(Box, SortOrder)>, + pub order_by: Vec<(ast::Expr, SortOrder)>, /// all the aggregates collected from the result columns, order by, and (TODO) having clauses pub aggregates: Vec, /// limit clause @@ -376,7 +376,7 @@ pub struct DeletePlan { /// where clause split into a vec at 'AND' boundaries. pub where_clause: Vec, /// order by clause - pub order_by: Vec<(Box, SortOrder)>, + pub order_by: Vec<(ast::Expr, SortOrder)>, /// limit clause pub limit: Option, /// offset clause @@ -391,9 +391,9 @@ pub struct DeletePlan { pub struct UpdatePlan { pub table_references: TableReferences, // (colum index, new value) pairs - pub set_clauses: Vec<(usize, Box)>, + pub set_clauses: Vec<(usize, ast::Expr)>, pub where_clause: Vec, - pub order_by: Vec<(Box, SortOrder)>, + pub order_by: Vec<(ast::Expr, SortOrder)>, pub limit: Option, pub offset: Option, // TODO: optional RETURNING clause diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 43f012875..34287a713 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -75,7 +75,7 @@ pub fn resolve_aggregates( } aggs.push(Aggregate { func: f, - args: args.iter().map(|arg| *arg.clone()).collect(), + args: args.to_vec(), original_expr: expr.clone(), distinctness, }); @@ -411,7 +411,7 @@ fn parse_table( vtab_predicates: &mut Vec, qualified_name: &QualifiedName, maybe_alias: Option<&As>, - args: &[Box], + args: &[Expr], connection: &Arc, ) -> Result<()> { let normalized_qualified_name = normalize_ident(qualified_name.name.as_str()); @@ -547,7 +547,7 @@ fn parse_table( } fn transform_args_into_where_terms( - args: &[Box], + args: &[Expr], internal_id: 
TableInternalId, predicates: &mut Vec, table: &Table, @@ -567,7 +567,7 @@ fn transform_args_into_where_terms( column: i, is_rowid_alias: col.is_rowid_alias, }; - let expr = match arg_expr.as_ref() { + let expr = match arg_expr { Expr::Literal(Null) => Expr::IsNull(Box::new(column_expr)), other => Expr::Binary( column_expr.into(), diff --git a/core/translate/select.rs b/core/translate/select.rs index d276471aa..580312077 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -376,8 +376,7 @@ fn prepare_one_select_plan( // COUNT() case vec![ast::Expr::Literal(ast::Literal::Numeric( "1".to_string(), - )) - .into()] + ))] } (true, _) => crate::bail_parse_error!( "Aggregate function {} requires arguments", @@ -388,7 +387,7 @@ fn prepare_one_select_plan( let agg = Aggregate { func: f, - args: agg_args.iter().map(|arg| *arg.clone()).collect(), + args: agg_args.to_vec(), original_expr: *expr.clone(), distinctness, }; @@ -448,10 +447,7 @@ fn prepare_one_select_plan( } else { let agg = Aggregate { func: AggFunc::External(f.func.clone().into()), - args: args - .iter() - .map(|arg| *arg.clone()) - .collect(), + args: args.to_vec(), original_expr: *expr.clone(), distinctness, }; @@ -571,7 +567,7 @@ fn prepare_one_select_plan( plan.group_by = Some(GroupBy { sort_order: Some((0..group_by.exprs.len()).map(|_| SortOrder::Asc).collect()), - exprs: group_by.exprs.iter().map(|expr| *expr.clone()).collect(), + exprs: group_by.exprs.to_vec(), having: if let Some(having) = group_by.having { let mut predicates = vec![]; break_predicate_at_and_boundaries(&having, &mut predicates); @@ -617,7 +613,7 @@ fn prepare_one_select_plan( )?; resolve_aggregates(schema, &o.expr, &mut plan.aggregates)?; - key.push((o.expr, o.order.unwrap_or(ast::SortOrder::Asc))); + key.push((*o.expr, o.order.unwrap_or(ast::SortOrder::Asc))); } plan.order_by = key; @@ -651,10 +647,7 @@ fn prepare_one_select_plan( contains_constant_false_condition: false, query_destination, distinctness: 
Distinctness::NonDistinct, - values: values - .iter() - .map(|values| values.iter().map(|value| *value.clone()).collect()) - .collect(), + values: values.iter().map(|values| values.to_vec()).collect(), }; Ok(plan) diff --git a/core/translate/update.rs b/core/translate/update.rs index f94ebe118..a2e08e9fc 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -175,7 +175,7 @@ pub fn prepare_update_plan( let values = match set.expr.as_ref() { Expr::Parenthesized(vals) => vals.clone(), - expr => vec![expr.clone().into()], + expr => vec![expr.clone()], }; if set.col_names.len() != values.len() { @@ -213,7 +213,7 @@ pub fn prepare_update_plan( let order_by = body .order_by .iter() - .map(|o| (o.expr.clone(), o.order.unwrap_or(SortOrder::Asc))) + .map(|o| (*o.expr.clone(), o.order.unwrap_or(SortOrder::Asc))) .collect(); // Sqlite determines we should create an ephemeral table if we do not have a FROM clause diff --git a/core/util.rs b/core/util.rs index 9b7d9e2c8..81bfdedb0 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1190,8 +1190,8 @@ pub fn parse_pragma_bool(expr: &Expr) -> Result { } /// Extract column name from an expression (e.g., for SELECT clauses) -pub fn extract_column_name_from_expr(expr: impl AsRef) -> Option { - match expr.as_ref() { +pub fn extract_column_name_from_expr(expr: &ast::Expr) -> Option { + match expr { ast::Expr::Id(name) => Some(name.as_str().to_string()), ast::Expr::Qualified(_, name) => Some(name.as_str().to_string()), _ => None, @@ -1435,7 +1435,7 @@ pub mod tests { let func1 = Expr::FunctionCall { name: Name::Ident("SUM".to_string()), distinctness: None, - args: vec![Expr::Id(Name::Ident("x".to_string())).into()], + args: vec![Expr::Id(Name::Ident("x".to_string()))], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1445,7 +1445,7 @@ pub mod tests { let func2 = Expr::FunctionCall { name: Name::Ident("sum".to_string()), distinctness: None, - args: vec![Expr::Id(Name::Ident("x".to_string())).into()], + 
args: vec![Expr::Id(Name::Ident("x".to_string()))], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1457,7 +1457,7 @@ pub mod tests { let func3 = Expr::FunctionCall { name: Name::Ident("SUM".to_string()), distinctness: Some(ast::Distinctness::Distinct), - args: vec![Expr::Id(Name::Ident("x".to_string())).into()], + args: vec![Expr::Id(Name::Ident("x".to_string()))], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1472,7 +1472,7 @@ pub mod tests { let sum = Expr::FunctionCall { name: Name::Ident("SUM".to_string()), distinctness: None, - args: vec![Expr::Id(Name::Ident("x".to_string())).into()], + args: vec![Expr::Id(Name::Ident("x".to_string()))], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1482,7 +1482,7 @@ pub mod tests { let sum_distinct = Expr::FunctionCall { name: Name::Ident("SUM".to_string()), distinctness: Some(ast::Distinctness::Distinct), - args: vec![Expr::Id(Name::Ident("x".to_string())).into()], + args: vec![Expr::Id(Name::Ident("x".to_string()))], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1513,8 +1513,7 @@ pub mod tests { Box::new(Expr::Literal(Literal::Numeric("683".to_string()))), Add, Box::new(Expr::Literal(Literal::Numeric("799.0".to_string()))), - ) - .into()]); + )]); let expr2 = Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("799".to_string()))), Add, @@ -1528,8 +1527,7 @@ pub mod tests { Box::new(Expr::Literal(Literal::Numeric("6".to_string()))), Add, Box::new(Expr::Literal(Literal::Numeric("7".to_string()))), - ) - .into()]); + )]); let expr8 = Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("6".to_string()))), Add, From 8010b7d0c7e332889ab1291ea852665a65d646f8 Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Mon, 25 Aug 2025 21:06:23 -0300 Subject: [PATCH 63/73] make simulator use `sql_generation` crate as dependency --- Cargo.lock | 8 +- Cargo.toml | 1 + simulator/Cargo.toml | 10 +- simulator/generation/expr.rs | 293 ------------ 
simulator/generation/mod.rs | 165 +------ simulator/generation/plan.rs | 20 +- simulator/generation/predicate/binary.rs | 586 ----------------------- simulator/generation/predicate/mod.rs | 378 --------------- simulator/generation/predicate/unary.rs | 306 ------------ simulator/generation/property.rs | 43 +- simulator/generation/query.rs | 369 +------------- simulator/generation/table.rs | 258 ---------- simulator/main.rs | 4 +- simulator/model/mod.rs | 419 +++++++++++++++- simulator/model/query/create.rs | 45 -- simulator/model/query/create_index.rs | 106 ---- simulator/model/query/delete.rs | 41 -- simulator/model/query/drop.rs | 34 -- simulator/model/query/insert.rs | 87 ---- simulator/model/query/mod.rs | 129 ----- simulator/model/query/predicate.rs | 146 ------ simulator/model/query/select.rs | 497 ------------------- simulator/model/query/transaction.rs | 60 --- simulator/model/query/update.rs | 71 --- simulator/model/table.rs | 428 ----------------- simulator/runner/clock.rs | 2 +- simulator/runner/differential.rs | 4 +- simulator/runner/doublecheck.rs | 5 +- simulator/runner/env.rs | 21 +- simulator/runner/execution.rs | 2 +- simulator/runner/file.rs | 6 +- simulator/runner/mod.rs | 2 + simulator/runner/watch.rs | 7 +- simulator/shrink/plan.rs | 2 +- sql_generation/generation/mod.rs | 5 +- sql_generation/model/query/mod.rs | 69 --- 36 files changed, 499 insertions(+), 4130 deletions(-) delete mode 100644 simulator/generation/expr.rs delete mode 100644 simulator/generation/predicate/binary.rs delete mode 100644 simulator/generation/predicate/mod.rs delete mode 100644 simulator/generation/predicate/unary.rs delete mode 100644 simulator/generation/table.rs delete mode 100644 simulator/model/query/create.rs delete mode 100644 simulator/model/query/create_index.rs delete mode 100644 simulator/model/query/delete.rs delete mode 100644 simulator/model/query/drop.rs delete mode 100644 simulator/model/query/insert.rs delete mode 100644 simulator/model/query/mod.rs 
delete mode 100644 simulator/model/query/predicate.rs delete mode 100644 simulator/model/query/select.rs delete mode 100644 simulator/model/query/transaction.rs delete mode 100644 simulator/model/query/update.rs delete mode 100644 simulator/model/table.rs diff --git a/Cargo.lock b/Cargo.lock index 1569f69ce..e599b2f8b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2134,7 +2134,6 @@ dependencies = [ name = "limbo_sim" version = "0.1.4" dependencies = [ - "anarchist-readable-name-generator-lib 0.1.2", "anyhow", "chrono", "clap", @@ -2144,17 +2143,18 @@ dependencies = [ "itertools 0.14.0", "log", "notify", - "rand 0.8.5", - "rand_chacha 0.3.1", + "rand 0.9.2", + "rand_chacha 0.9.0", "regex", "regex-syntax 0.8.5", "rusqlite", "serde", "serde_json", + "sql_generation", "tracing", "tracing-subscriber", "turso_core", - "turso_sqlite3_parser", + "turso_parser", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 092d76d98..0dbb4b1fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ limbo_regexp = { path = "extensions/regexp", version = "0.1.4" } turso_sqlite3_parser = { path = "vendored/sqlite3-parser", version = "0.1.4" } limbo_uuid = { path = "extensions/uuid", version = "0.1.4" } turso_parser = { path = "parser" } +sql_generation = { path = "sql_generation" } strum = { version = "0.26", features = ["derive"] } strum_macros = "0.26" serde = "1.0" diff --git a/simulator/Cargo.toml b/simulator/Cargo.toml index 20696fa91..f01896716 100644 --- a/simulator/Cargo.toml +++ b/simulator/Cargo.toml @@ -16,15 +16,14 @@ path = "main.rs" [dependencies] turso_core = { path = "../core", features = ["simulator"]} -rand = "0.8.5" -rand_chacha = "0.3.1" +rand = { workspace = true } +rand_chacha = "0.9.0" log = "0.4.20" env_logger = "0.10.1" regex = "1.11.1" regex-syntax = { version = "0.8.5", default-features = false, features = [ "unicode", ] } -anarchist-readable-name-generator-lib = "=0.1.2" clap = { version = "4.5", features = ["derive"] } serde = { workspace = true, 
features = ["derive"] } serde_json = { workspace = true } @@ -32,9 +31,10 @@ notify = "8.0.0" rusqlite.workspace = true dirs = "6.0.0" chrono = { version = "0.4.40", features = ["serde"] } -tracing = "0.1.41" +tracing = { workspace = true } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } anyhow.workspace = true -turso_sqlite3_parser = { workspace = true, features = ["serde"]} hex = "0.4.3" itertools = "0.14.0" +sql_generation = { workspace = true } +turso_parser = { workspace = true } diff --git a/simulator/generation/expr.rs b/simulator/generation/expr.rs deleted file mode 100644 index 682c38d5c..000000000 --- a/simulator/generation/expr.rs +++ /dev/null @@ -1,293 +0,0 @@ -use turso_sqlite3_parser::ast::{ - self, Expr, LikeOperator, Name, Operator, QualifiedName, Type, UnaryOperator, -}; - -use crate::{ - generation::{ - frequency, gen_random_text, one_of, pick, pick_index, Arbitrary, ArbitraryFrom, - ArbitrarySizedFrom, - }, - model::table::SimValue, - SimulatorEnv, -}; - -impl Arbitrary for Box -where - T: Arbitrary, -{ - fn arbitrary(rng: &mut R) -> Self { - Box::from(T::arbitrary(rng)) - } -} - -impl ArbitrarySizedFrom for Box -where - T: ArbitrarySizedFrom, -{ - fn arbitrary_sized_from(rng: &mut R, t: A, size: usize) -> Self { - Box::from(T::arbitrary_sized_from(rng, t, size)) - } -} - -impl Arbitrary for Option -where - T: Arbitrary, -{ - fn arbitrary(rng: &mut R) -> Self { - rng.gen_bool(0.5).then_some(T::arbitrary(rng)) - } -} - -impl ArbitrarySizedFrom for Option -where - T: ArbitrarySizedFrom, -{ - fn arbitrary_sized_from(rng: &mut R, t: A, size: usize) -> Self { - rng.gen_bool(0.5) - .then_some(T::arbitrary_sized_from(rng, t, size)) - } -} - -impl ArbitraryFrom for Vec -where - T: ArbitraryFrom, -{ - fn arbitrary_from(rng: &mut R, t: A) -> Self { - let size = rng.gen_range(0..5); - (0..size).map(|_| T::arbitrary_from(rng, t)).collect() - } -} - -// Freestyling generation -impl ArbitrarySizedFrom<&SimulatorEnv> for Expr { - fn 
arbitrary_sized_from(rng: &mut R, t: &SimulatorEnv, size: usize) -> Self { - frequency( - vec![ - ( - 1, - Box::new(|rng| Expr::Literal(ast::Literal::arbitrary_from(rng, t))), - ), - ( - size, - Box::new(|rng| { - one_of( - vec![ - // Box::new(|rng: &mut R| Expr::Between { - // lhs: Box::arbitrary_sized_from(rng, t, size - 1), - // not: rng.gen_bool(0.5), - // start: Box::arbitrary_sized_from(rng, t, size - 1), - // end: Box::arbitrary_sized_from(rng, t, size - 1), - // }), - Box::new(|rng: &mut R| { - Expr::Binary( - Box::arbitrary_sized_from(rng, t, size - 1), - Operator::arbitrary(rng), - Box::arbitrary_sized_from(rng, t, size - 1), - ) - }), - // Box::new(|rng| Expr::Case { - // base: Option::arbitrary_from(rng, t), - // when_then_pairs: { - // let size = rng.gen_range(0..5); - // (0..size) - // .map(|_| (Self::arbitrary_from(rng, t), Self::arbitrary_from(rng, t))) - // .collect() - // }, - // else_expr: Option::arbitrary_from(rng, t), - // }), - // Box::new(|rng| Expr::Cast { - // expr: Box::arbitrary_sized_from(rng, t), - // type_name: Option::arbitrary(rng), - // }), - // Box::new(|rng| Expr::Collate(Box::arbitrary_sized_from(rng, t), CollateName::arbitrary(rng).0)), - // Box::new(|rng| Expr::InList { - // lhs: Box::arbitrary_sized_from(rng, t), - // not: rng.gen_bool(0.5), - // rhs: Option::arbitrary_from(rng, t), - // }), - // Box::new(|rng| Expr::IsNull(Box::arbitrary_sized_from(rng, t))), - // Box::new(|rng| { - // // let op = LikeOperator::arbitrary_from(rng, t); - // let op = ast::LikeOperator::Like; // todo: remove this line when LikeOperator is implemented - // let escape = if matches!(op, LikeOperator::Like) { - // Option::arbitrary_sized_from(rng, t, size - 1) - // } else { - // None - // }; - // Expr::Like { - // lhs: Box::arbitrary_sized_from(rng, t, size - 1), - // not: rng.gen_bool(0.5), - // op, - // rhs: Box::arbitrary_sized_from(rng, t, size - 1), - // escape, - // } - // }), - // Box::new(|rng| Expr::NotNull(Box::arbitrary_sized_from(rng, 
t))), - // // TODO: only supports one paranthesized expression - // Box::new(|rng| Expr::Parenthesized(vec![Expr::arbitrary_from(rng, t)])), - // Box::new(|rng| { - // let table_idx = pick_index(t.tables.len(), rng); - // let table = &t.tables[table_idx]; - // let col_idx = pick_index(table.columns.len(), rng); - // let col = &table.columns[col_idx]; - // Expr::Qualified(Name(table.name.clone()), Name(col.name.clone())) - // }) - Box::new(|rng| { - Expr::Unary( - UnaryOperator::arbitrary_from(rng, t), - Box::arbitrary_sized_from(rng, t, size - 1), - ) - }), - // TODO: skip Exists for now - // TODO: skip Function Call for now - // TODO: skip Function Call Star for now - // TODO: skip ID for now - // TODO: skip InSelect as still need to implement ArbitratyFrom for Select - // TODO: skip InTable - // TODO: skip Name - // TODO: Skip DoublyQualified for now - // TODO: skip Raise - // TODO: skip subquery - ], - rng, - ) - }), - ), - ], - rng, - ) - } -} - -impl Arbitrary for Operator { - fn arbitrary(rng: &mut R) -> Self { - let choices = [ - Operator::Add, - Operator::And, - // Operator::ArrowRight, -- todo: not implemented in `binary_compare` yet - // Operator::ArrowRightShift, -- todo: not implemented in `binary_compare` yet - Operator::BitwiseAnd, - // Operator::BitwiseNot, -- todo: not implemented in `binary_compare` yet - Operator::BitwiseOr, - // Operator::Concat, -- todo: not implemented in `exec_concat` - Operator::Divide, - Operator::Equals, - Operator::Greater, - Operator::GreaterEquals, - Operator::Is, - Operator::IsNot, - Operator::LeftShift, - Operator::Less, - Operator::LessEquals, - Operator::Modulus, - Operator::Multiply, - Operator::NotEquals, - Operator::Or, - Operator::RightShift, - Operator::Subtract, - ]; - *pick(&choices, rng) - } -} - -impl Arbitrary for Type { - fn arbitrary(rng: &mut R) -> Self { - let name = pick(&["INT", "INTEGER", "REAL", "TEXT", "BLOB", "ANY"], rng).to_string(); - Self { - name, - size: None, // TODO: come back later here - 
} - } -} - -struct CollateName(String); - -impl Arbitrary for CollateName { - fn arbitrary(rng: &mut R) -> Self { - let choice = rng.gen_range(0..3); - CollateName( - match choice { - 0 => "BINARY", - 1 => "RTRIM", - 2 => "NOCASE", - _ => unreachable!(), - } - .to_string(), - ) - } -} - -impl ArbitraryFrom<&SimulatorEnv> for QualifiedName { - fn arbitrary_from(rng: &mut R, t: &SimulatorEnv) -> Self { - // TODO: for now just generate table name - let table_idx = pick_index(t.tables.len(), rng); - let table = &t.tables[table_idx]; - // TODO: for now forego alias - Self::single(Name::from_str(&table.name)) - } -} - -impl ArbitraryFrom<&SimulatorEnv> for LikeOperator { - fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { - let choice = rng.gen_range(0..4); - match choice { - 0 => LikeOperator::Glob, - 1 => LikeOperator::Like, - 2 => LikeOperator::Match, - 3 => LikeOperator::Regexp, - _ => unreachable!(), - } - } -} - -// Current implementation does not take into account the columns affinity nor if table is Strict -impl ArbitraryFrom<&SimulatorEnv> for ast::Literal { - fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { - loop { - let choice = rng.gen_range(0..5); - let lit = match choice { - 0 => ast::Literal::Numeric({ - let integer = rng.gen_bool(0.5); - if integer { - rng.gen_range(i64::MIN..i64::MAX).to_string() - } else { - rng.gen_range(-1e10..1e10).to_string() - } - }), - 1 => ast::Literal::String(format!("'{}'", gen_random_text(rng))), - 2 => ast::Literal::Blob(hex::encode(gen_random_text(rng).as_bytes())), - // TODO: skip Keyword - 3 => continue, - 4 => ast::Literal::Null, - // TODO: Ignore Date stuff for now - _ => continue, - }; - break lit; - } - } -} - -// Creates a litreal value -impl ArbitraryFrom<&Vec<&SimValue>> for ast::Expr { - fn arbitrary_from(rng: &mut R, values: &Vec<&SimValue>) -> Self { - if values.is_empty() { - return Self::Literal(ast::Literal::Null); - } - // TODO: for now just convert the value to an ast::Literal - let 
value = pick(values, rng); - Expr::Literal((*value).into()) - } -} - -impl ArbitraryFrom<&SimulatorEnv> for UnaryOperator { - fn arbitrary_from(rng: &mut R, _t: &SimulatorEnv) -> Self { - let choice = rng.gen_range(0..4); - match choice { - 0 => Self::BitwiseNot, - 1 => Self::Negative, - 2 => Self::Not, - 3 => Self::Positive, - _ => unreachable!(), - } - } -} diff --git a/simulator/generation/mod.rs b/simulator/generation/mod.rs index 6d944b065..79bdf506f 100644 --- a/simulator/generation/mod.rs +++ b/simulator/generation/mod.rs @@ -1,65 +1,10 @@ -use std::{iter::Sum, ops::SubAssign}; +use sql_generation::generation::GenerationContext; -use anarchist_readable_name_generator_lib::readable_name_custom; -use rand::{distributions::uniform::SampleUniform, Rng}; +use crate::runner::env::{SimulatorEnv, SimulatorTables}; -use crate::runner::env::SimulatorTables; - -mod expr; pub mod plan; -mod predicate; pub mod property; pub mod query; -pub mod table; - -type ArbitraryFromFunc<'a, R, T> = Box T + 'a>; -type Choice<'a, R, T> = (usize, Box Option + 'a>); - -/// Arbitrary trait for generating random values -/// An implementation of arbitrary is assumed to be a uniform sampling of -/// the possible values of the type, with a bias towards smaller values for -/// practicality. -pub trait Arbitrary { - fn arbitrary(rng: &mut R) -> Self; -} - -/// ArbitrarySized trait for generating random values of a specific size -/// An implementation of arbitrary_sized is assumed to be a uniform sampling of -/// the possible values of the type, with a bias towards smaller values for -/// practicality, but with the additional constraint that the generated value -/// must fit in the given size. This is useful for generating values that are -/// constrained by a specific size, such as integers or strings. 
-pub trait ArbitrarySized { - fn arbitrary_sized(rng: &mut R, size: usize) -> Self; -} - -/// ArbitraryFrom trait for generating random values from a given value -/// ArbitraryFrom allows for constructing relations, where the generated -/// value is dependent on the given value. These relations could be constraints -/// such as generating an integer within an interval, or a value that fits in a table, -/// or a predicate satisfying a given table row. -pub trait ArbitraryFrom { - fn arbitrary_from(rng: &mut R, t: T) -> Self; -} - -/// ArbitrarySizedFrom trait for generating random values from a given value -/// ArbitrarySizedFrom allows for constructing relations, where the generated -/// value is dependent on the given value and a size constraint. These relations -/// could be constraints such as generating an integer within an interval, -/// or a value that fits in a table, or a predicate satisfying a given table row, -/// but with the additional constraint that the generated value must fit in the given size. -/// This is useful for generating values that are constrained by a specific size, -/// such as integers or strings, while still being dependent on the given value. -pub trait ArbitrarySizedFrom { - fn arbitrary_sized_from(rng: &mut R, t: T, size: usize) -> Self; -} - -/// ArbitraryFromMaybe trait for fallibally generating random values from a given value -pub trait ArbitraryFromMaybe { - fn arbitrary_from_maybe(rng: &mut R, t: T) -> Option - where - Self: Sized; -} /// Shadow trait for types that can be "shadowed" in the simulator environment. /// Shadowing is a process of applying a transformation to the simulator environment @@ -75,108 +20,14 @@ pub(crate) trait Shadow { fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result; } -/// Frequency is a helper function for composing different generators with different frequency -/// of occurrences. 
-/// The type signature for the `N` parameter is a bit complex, but it -/// roughly corresponds to a type that can be summed, compared, subtracted and sampled, which are -/// the operations we require for the implementation. -// todo: switch to a simpler type signature that can accommodate all integer and float types, which -// should be enough for our purposes. -pub(crate) fn frequency< - T, - R: Rng, - N: Sum + PartialOrd + Copy + Default + SampleUniform + SubAssign, ->( - choices: Vec<(N, ArbitraryFromFunc)>, - rng: &mut R, -) -> T { - let total = choices.iter().map(|(weight, _)| *weight).sum::(); - let mut choice = rng.gen_range(N::default()..total); - - for (weight, f) in choices { - if choice < weight { - return f(rng); - } - choice -= weight; +impl GenerationContext for SimulatorEnv { + fn tables(&self) -> &Vec { + &self.tables.tables } - unreachable!() -} - -/// one_of is a helper function for composing different generators with equal probability of occurrence. -pub(crate) fn one_of(choices: Vec>, rng: &mut R) -> T { - let index = rng.gen_range(0..choices.len()); - choices[index](rng) -} - -/// backtrack is a helper function for composing different "failable" generators. -/// The function takes a list of functions that return an Option, along with number of retries -/// to make before giving up. 
-pub(crate) fn backtrack(mut choices: Vec>, rng: &mut R) -> Option { - loop { - // If there are no more choices left, we give up - let choices_ = choices - .iter() - .enumerate() - .filter(|(_, (retries, _))| *retries > 0) - .collect::>(); - if choices_.is_empty() { - tracing::trace!("backtrack: no more choices left"); - return None; - } - // Run a one_of on the remaining choices - let (choice_index, choice) = pick(&choices_, rng); - let choice_index = *choice_index; - // If the choice returns None, we decrement the number of retries and try again - let result = choice.1(rng); - if result.is_some() { - return result; - } else { - choices[choice_index].0 -= 1; + fn opts(&self) -> sql_generation::generation::Opts { + sql_generation::generation::Opts { + indexes: self.opts.experimental_indexes, } } } - -/// pick is a helper function for uniformly picking a random element from a slice -pub(crate) fn pick<'a, T, R: Rng>(choices: &'a [T], rng: &mut R) -> &'a T { - let index = rng.gen_range(0..choices.len()); - &choices[index] -} - -/// pick_index is typically used for picking an index from a slice to later refer to the element -/// at that index. -pub(crate) fn pick_index(choices: usize, rng: &mut R) -> usize { - rng.gen_range(0..choices) -} - -/// pick_n_unique is a helper function for uniformly picking N unique elements from a range. -/// The elements themselves are usize, typically representing indices. -pub(crate) fn pick_n_unique( - range: std::ops::Range, - n: usize, - rng: &mut R, -) -> Vec { - use rand::seq::SliceRandom; - let mut items: Vec = range.collect(); - items.shuffle(rng); - items.into_iter().take(n).collect() -} - -/// gen_random_text uses `anarchist_readable_name_generator_lib` to generate random -/// readable names for tables, columns, text values etc. 
-pub(crate) fn gen_random_text(rng: &mut T) -> String { - let big_text = rng.gen_ratio(1, 1000); - if big_text { - // let max_size: u64 = 2 * 1024 * 1024 * 1024; - let max_size: u64 = 2 * 1024; - let size = rng.gen_range(1024..max_size); - let mut name = String::with_capacity(size as usize); - for i in 0..size { - name.push(((i % 26) as u8 + b'A') as char); - } - name - } else { - let name = readable_name_custom("_", rng); - name.replace("-", "_") - } -} diff --git a/simulator/generation/plan.rs b/simulator/generation/plan.rs index c27becad5..47763657a 100644 --- a/simulator/generation/plan.rs +++ b/simulator/generation/plan.rs @@ -8,14 +8,18 @@ use std::{ use serde::{Deserialize, Serialize}; +use sql_generation::{ + generation::{frequency, query::SelectFree, Arbitrary, ArbitraryFrom}, + model::{ + query::{update::Update, Create, CreateIndex, Delete, Drop, Insert, Select}, + table::SimValue, + }, +}; use turso_core::{Connection, Result, StepResult}; use crate::{ - generation::{query::SelectFree, Shadow}, - model::{ - query::{update::Update, Create, CreateIndex, Delete, Drop, Insert, Query, Select}, - table::SimValue, - }, + generation::Shadow, + model::Query, runner::{ env::{SimConnection, SimulationType, SimulatorTables}, io::SimulatorIO, @@ -23,8 +27,6 @@ use crate::{ SimulatorEnv, }; -use crate::generation::{frequency, Arbitrary, ArbitraryFrom}; - use super::property::{remaining, Property}; pub(crate) type ResultSet = Result>>; @@ -661,7 +663,7 @@ impl Interaction { .iter() .any(|file| file.sync_completion.borrow().is_some()) }; - let inject_fault = env.rng.gen_bool(current_prob); + let inject_fault = env.rng.random_bool(current_prob); // TODO: avoid for now injecting faults when syncing if inject_fault && !syncing { env.io.inject_fault(true); @@ -811,7 +813,7 @@ fn random_fault(rng: &mut R, env: &SimulatorEnv) -> Interactions { } else { vec![Fault::Disconnect, Fault::ReopenDatabase] }; - let fault = faults[rng.gen_range(0..faults.len())].clone(); + let fault = 
faults[rng.random_range(0..faults.len())].clone(); Interactions::Fault(fault) } diff --git a/simulator/generation/predicate/binary.rs b/simulator/generation/predicate/binary.rs deleted file mode 100644 index f8ba27236..000000000 --- a/simulator/generation/predicate/binary.rs +++ /dev/null @@ -1,586 +0,0 @@ -//! Contains code for generation for [ast::Expr::Binary] Predicate - -use turso_sqlite3_parser::ast::{self, Expr}; - -use crate::{ - generation::{ - backtrack, one_of, pick, - predicate::{CompoundPredicate, SimplePredicate}, - table::{GTValue, LTValue, LikeValue}, - ArbitraryFrom, ArbitraryFromMaybe as _, - }, - model::{ - query::predicate::Predicate, - table::{SimValue, Table, TableContext}, - }, -}; - -impl Predicate { - /// Generate an [ast::Expr::Binary] [Predicate] from a column and [SimValue] - pub fn from_column_binary( - rng: &mut R, - column_name: &str, - value: &SimValue, - ) -> Predicate { - let expr = one_of( - vec![ - Box::new(|_| { - Expr::Binary( - Box::new(Expr::Id(ast::Name::Ident(column_name.to_string()))), - ast::Operator::Equals, - Box::new(Expr::Literal(value.into())), - ) - }), - Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, value).0; - Expr::Binary( - Box::new(Expr::Id(ast::Name::Ident(column_name.to_string()))), - ast::Operator::Greater, - Box::new(Expr::Literal(gt_value.into())), - ) - }), - Box::new(|rng| { - let lt_value = LTValue::arbitrary_from(rng, value).0; - Expr::Binary( - Box::new(Expr::Id(ast::Name::Ident(column_name.to_string()))), - ast::Operator::Less, - Box::new(Expr::Literal(lt_value.into())), - ) - }), - ], - rng, - ); - Predicate(expr) - } - - /// Produces a true [ast::Expr::Binary] [Predicate] that is true for the provided row in the given table - pub fn true_binary(rng: &mut R, t: &Table, row: &[SimValue]) -> Predicate { - // Pick a column - let column_index = rng.gen_range(0..t.columns.len()); - let mut column = t.columns[column_index].clone(); - let value = &row[column_index]; - - let mut table_name 
= t.name.clone(); - if t.name.is_empty() { - // If the table name is empty, we cannot create a qualified expression - // so we use the column name directly - let mut splitted = column.name.split('.'); - table_name = splitted - .next() - .expect("Column name should have a table prefix for a joined table") - .to_string(); - column.name = splitted - .next() - .expect("Column name should have a column suffix for a joined table") - .to_string(); - } - - let expr = backtrack( - vec![ - ( - 1, - Box::new(|_| { - Some(Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), - )), - ast::Operator::Equals, - Box::new(Expr::Literal(value.into())), - )) - }), - ), - ( - 1, - Box::new(|rng| { - let v = SimValue::arbitrary_from(rng, &column.column_type); - if &v == value { - None - } else { - Some(Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), - )), - ast::Operator::NotEquals, - Box::new(Expr::Literal(v.into())), - )) - } - }), - ), - ( - 1, - Box::new(|rng| { - let lt_value = LTValue::arbitrary_from(rng, value).0; - Some(Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), - )), - ast::Operator::Greater, - Box::new(Expr::Literal(lt_value.into())), - )) - }), - ), - ( - 1, - Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, value).0; - Some(Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), - )), - ast::Operator::Less, - Box::new(Expr::Literal(gt_value.into())), - )) - }), - ), - ( - 1, - Box::new(|rng| { - // TODO: generation for Like and Glob expressions should be extracted to different module - LikeValue::arbitrary_from_maybe(rng, value).map(|like| { - Expr::Like { - lhs: Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), - )), - not: false, // 
TODO: also generate this value eventually - op: ast::LikeOperator::Like, - rhs: Box::new(Expr::Literal(like.0.into())), - escape: None, // TODO: implement - } - }) - }), - ), - ], - rng, - ); - // Backtrack will always return Some here - Predicate(expr.unwrap()) - } - - /// Produces an [ast::Expr::Binary] [Predicate] that is false for the provided row in the given table - pub fn false_binary(rng: &mut R, t: &Table, row: &[SimValue]) -> Predicate { - // Pick a column - let column_index = rng.gen_range(0..t.columns.len()); - let mut column = t.columns[column_index].clone(); - let mut table_name = t.name.clone(); - let value = &row[column_index]; - - if t.name.is_empty() { - // If the table name is empty, we cannot create a qualified expression - // so we use the column name directly - let mut splitted = column.name.split('.'); - table_name = splitted - .next() - .expect("Column name should have a table prefix for a joined table") - .to_string(); - column.name = splitted - .next() - .expect("Column name should have a column suffix for a joined table") - .to_string(); - } - - let expr = one_of( - vec![ - Box::new(|_| { - Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), - )), - ast::Operator::NotEquals, - Box::new(Expr::Literal(value.into())), - ) - }), - Box::new(|rng| { - let v = loop { - let v = SimValue::arbitrary_from(rng, &column.column_type); - if &v != value { - break v; - } - }; - Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), - )), - ast::Operator::Equals, - Box::new(Expr::Literal(v.into())), - ) - }), - Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, value).0; - Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), - )), - ast::Operator::Greater, - Box::new(Expr::Literal(gt_value.into())), - ) - }), - Box::new(|rng| { - let lt_value = 
LTValue::arbitrary_from(rng, value).0; - Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(&table_name), - ast::Name::from_str(&column.name), - )), - ast::Operator::Less, - Box::new(Expr::Literal(lt_value.into())), - ) - }), - ], - rng, - ); - Predicate(expr) - } -} - -impl SimplePredicate { - /// Generates a true [ast::Expr::Binary] [SimplePredicate] from a [TableContext] for a row in the table - pub fn true_binary( - rng: &mut R, - table: &T, - row: &[SimValue], - ) -> Self { - // Pick a random column - let columns = table.columns().collect::>(); - let column_index = rng.gen_range(0..columns.len()); - let column = columns[column_index]; - let column_value = &row[column_index]; - let table_name = column.table_name; - // Avoid creation of NULLs - if row.is_empty() { - return SimplePredicate(Predicate(Expr::Literal(SimValue::TRUE.into()))); - } - - let expr = one_of( - vec![ - Box::new(|_rng| { - Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), - )), - ast::Operator::Equals, - Box::new(Expr::Literal(column_value.into())), - ) - }), - Box::new(|rng| { - let lt_value = LTValue::arbitrary_from(rng, column_value).0; - Expr::Binary( - Box::new(Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), - )), - ast::Operator::Greater, - Box::new(Expr::Literal(lt_value.into())), - ) - }), - Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, column_value).0; - Expr::Binary( - Box::new(Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), - )), - ast::Operator::Less, - Box::new(Expr::Literal(gt_value.into())), - ) - }), - ], - rng, - ); - SimplePredicate(Predicate(expr)) - } - - /// Generates a false [ast::Expr::Binary] [SimplePredicate] from a [TableContext] for a row in the table - pub fn false_binary( - rng: &mut R, - table: &T, - row: &[SimValue], - ) -> Self { - let columns = 
table.columns().collect::>(); - // Pick a random column - let column_index = rng.gen_range(0..columns.len()); - let column = columns[column_index]; - let column_value = &row[column_index]; - let table_name = column.table_name; - // Avoid creation of NULLs - if row.is_empty() { - return SimplePredicate(Predicate(Expr::Literal(SimValue::FALSE.into()))); - } - - let expr = one_of( - vec![ - Box::new(|_rng| { - Expr::Binary( - Box::new(Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), - )), - ast::Operator::NotEquals, - Box::new(Expr::Literal(column_value.into())), - ) - }), - Box::new(|rng| { - let gt_value = GTValue::arbitrary_from(rng, column_value).0; - Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), - )), - ast::Operator::Greater, - Box::new(Expr::Literal(gt_value.into())), - ) - }), - Box::new(|rng| { - let lt_value = LTValue::arbitrary_from(rng, column_value).0; - Expr::Binary( - Box::new(ast::Expr::Qualified( - ast::Name::from_str(table_name), - ast::Name::from_str(&column.column.name), - )), - ast::Operator::Less, - Box::new(Expr::Literal(lt_value.into())), - ) - }), - ], - rng, - ); - SimplePredicate(Predicate(expr)) - } -} - -impl CompoundPredicate { - /// Decide if you want to create an AND or an OR - /// - /// Creates a Compound Predicate that is TRUE or FALSE for at least a single row - pub fn from_table_binary( - rng: &mut R, - table: &T, - predicate_value: bool, - ) -> Self { - // Cannot pick a row if the table is empty - let rows = table.rows(); - if rows.is_empty() { - return Self(if predicate_value { - Predicate::true_() - } else { - Predicate::false_() - }); - } - let row = pick(rows, rng); - - let predicate = if rng.gen_bool(0.7) { - // An AND for true requires each of its children to be true - // An AND for false requires at least one of its children to be false - if predicate_value { - (0..rng.gen_range(1..=3)) - .map(|_| 
SimplePredicate::arbitrary_from(rng, (table, row, true)).0) - .reduce(|accum, curr| { - Predicate(Expr::Binary( - Box::new(accum.0), - ast::Operator::And, - Box::new(curr.0), - )) - }) - .unwrap_or(Predicate::true_()) - } else { - // Create a vector of random booleans - let mut booleans = (0..rng.gen_range(1..=3)) - .map(|_| rng.gen_bool(0.5)) - .collect::>(); - - let len = booleans.len(); - - // Make sure at least one of them is false - if booleans.iter().all(|b| *b) { - booleans[rng.gen_range(0..len)] = false; - } - - booleans - .iter() - .map(|b| SimplePredicate::arbitrary_from(rng, (table, row, *b)).0) - .reduce(|accum, curr| { - Predicate(Expr::Binary( - Box::new(accum.0), - ast::Operator::And, - Box::new(curr.0), - )) - }) - .unwrap_or(Predicate::false_()) - } - } else { - // An OR for true requires at least one of its children to be true - // An OR for false requires each of its children to be false - if predicate_value { - // Create a vector of random booleans - let mut booleans = (0..rng.gen_range(1..=3)) - .map(|_| rng.gen_bool(0.5)) - .collect::>(); - let len = booleans.len(); - // Make sure at least one of them is true - if booleans.iter().all(|b| !*b) { - booleans[rng.gen_range(0..len)] = true; - } - - booleans - .iter() - .map(|b| SimplePredicate::arbitrary_from(rng, (table, row, *b)).0) - .reduce(|accum, curr| { - Predicate(Expr::Binary( - Box::new(accum.0), - ast::Operator::Or, - Box::new(curr.0), - )) - }) - .unwrap_or(Predicate::true_()) - } else { - (0..rng.gen_range(1..=3)) - .map(|_| SimplePredicate::arbitrary_from(rng, (table, row, false)).0) - .reduce(|accum, curr| { - Predicate(Expr::Binary( - Box::new(accum.0), - ast::Operator::Or, - Box::new(curr.0), - )) - }) - .unwrap_or(Predicate::false_()) - } - }; - Self(predicate) - } -} - -#[cfg(test)] -mod tests { - use rand::{Rng as _, SeedableRng as _}; - use rand_chacha::ChaCha8Rng; - - use crate::{ - generation::{pick, predicate::SimplePredicate, Arbitrary, ArbitraryFrom as _}, - model::{ - 
query::predicate::{expr_to_value, Predicate}, - table::{SimValue, Table}, - }, - }; - - fn get_seed() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() - } - - #[test] - fn fuzz_true_binary_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - let row = pick(&values, &mut rng); - let predicate = Predicate::true_binary(&mut rng, &table, row); - let value = expr_to_value(&predicate.0, row, &table); - assert!( - value.as_ref().is_some_and(|value| value.as_bool()), - "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" - ) - } - } - - #[test] - fn fuzz_false_binary_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - let row = pick(&values, &mut rng); - let predicate = Predicate::false_binary(&mut rng, &table, row); - let value = expr_to_value(&predicate.0, row, &table); - assert!( - !value.as_ref().is_some_and(|value| value.as_bool()), - "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" - ) - } - } - - #[test] - fn fuzz_true_binary_simple_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - 
}) - .collect(); - table.rows.extend(values.clone()); - let row = pick(&table.rows, &mut rng); - let predicate = SimplePredicate::true_binary(&mut rng, &table, row); - let result = values - .iter() - .map(|row| predicate.0.test(row, &table)) - .reduce(|accum, curr| accum || curr) - .unwrap_or(false); - assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") - } - } - - #[test] - fn fuzz_false_binary_simple_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - table.rows.extend(values.clone()); - let row = pick(&table.rows, &mut rng); - let predicate = SimplePredicate::false_binary(&mut rng, &table, row); - let result = values - .iter() - .map(|row| predicate.0.test(row, &table)) - .any(|res| !res); - assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") - } - } -} diff --git a/simulator/generation/predicate/mod.rs b/simulator/generation/predicate/mod.rs deleted file mode 100644 index 5c5887818..000000000 --- a/simulator/generation/predicate/mod.rs +++ /dev/null @@ -1,378 +0,0 @@ -use rand::{seq::SliceRandom as _, Rng}; -use turso_sqlite3_parser::ast::{self, Expr}; - -use crate::model::{ - query::predicate::Predicate, - table::{SimValue, Table, TableContext}, -}; - -use super::{one_of, ArbitraryFrom}; - -mod binary; -mod unary; - -#[derive(Debug)] -struct CompoundPredicate(Predicate); - -#[derive(Debug)] -struct SimplePredicate(Predicate); - -impl, T: TableContext> ArbitraryFrom<(&T, A, bool)> for SimplePredicate { - fn arbitrary_from(rng: &mut R, (table, row, predicate_value): (&T, A, bool)) -> Self { - let row = row.as_ref(); - // Pick an operator - let choice = rng.gen_range(0..2); - // Pick an operator - match predicate_value { - true => match 
choice { - 0 => SimplePredicate::true_binary(rng, table, row), - 1 => SimplePredicate::true_unary(rng, table, row), - _ => unreachable!(), - }, - false => match choice { - 0 => SimplePredicate::false_binary(rng, table, row), - 1 => SimplePredicate::false_unary(rng, table, row), - _ => unreachable!(), - }, - } - } -} - -impl ArbitraryFrom<(&T, bool)> for CompoundPredicate { - fn arbitrary_from(rng: &mut R, (table, predicate_value): (&T, bool)) -> Self { - CompoundPredicate::from_table_binary(rng, table, predicate_value) - } -} - -impl ArbitraryFrom<&T> for Predicate { - fn arbitrary_from(rng: &mut R, table: &T) -> Self { - let predicate_value = rng.gen_bool(0.5); - Predicate::arbitrary_from(rng, (table, predicate_value)).parens() - } -} - -impl ArbitraryFrom<(&T, bool)> for Predicate { - fn arbitrary_from(rng: &mut R, (table, predicate_value): (&T, bool)) -> Self { - CompoundPredicate::arbitrary_from(rng, (table, predicate_value)).0 - } -} - -impl ArbitraryFrom<(&str, &SimValue)> for Predicate { - fn arbitrary_from(rng: &mut R, (column_name, value): (&str, &SimValue)) -> Self { - Predicate::from_column_binary(rng, column_name, value) - } -} - -impl ArbitraryFrom<(&Table, &Vec)> for Predicate { - fn arbitrary_from(rng: &mut R, (t, row): (&Table, &Vec)) -> Self { - // We want to produce a predicate that is true for the row - // We can do this by creating several predicates that - // are true, some that are false, combiend them in ways that correspond to the creation of a true predicate - - // Produce some true and false predicates - let mut true_predicates = (1..=rng.gen_range(1..=4)) - .map(|_| Predicate::true_binary(rng, t, row)) - .collect::>(); - - let false_predicates = (0..=rng.gen_range(0..=3)) - .map(|_| Predicate::false_binary(rng, t, row)) - .collect::>(); - - // Start building a top level predicate from a true predicate - let mut result = true_predicates.pop().unwrap(); - - let mut predicates = true_predicates - .iter() - .map(|p| (true, p.clone())) - 
.chain(false_predicates.iter().map(|p| (false, p.clone()))) - .collect::>(); - - predicates.shuffle(rng); - - while !predicates.is_empty() { - // Create a new predicate from at least 1 and at most 3 predicates - let context = - predicates[0..rng.gen_range(0..=usize::min(3, predicates.len()))].to_vec(); - // Shift `predicates` to remove the predicates in the context - predicates = predicates[context.len()..].to_vec(); - - // `result` is true, so we have the following three options to make a true predicate: - // T or F - // T or T - // T and T - - result = one_of( - vec![ - // T or (X1 or X2 or ... or Xn) - Box::new(|_| { - Predicate(Expr::Binary( - Box::new(result.0.clone()), - ast::Operator::Or, - Box::new( - context - .iter() - .map(|(_, p)| p.clone()) - .reduce(|accum, curr| { - Predicate(Expr::Binary( - Box::new(accum.0), - ast::Operator::Or, - Box::new(curr.0), - )) - }) - .unwrap_or(Predicate::false_()) - .0, - ), - )) - }), - // T or (T1 and T2 and ... and Tn) - Box::new(|_| { - Predicate(Expr::Binary( - Box::new(result.0.clone()), - ast::Operator::Or, - Box::new( - context - .iter() - .map(|(_, p)| p.clone()) - .reduce(|accum, curr| { - Predicate(Expr::Binary( - Box::new(accum.0), - ast::Operator::And, - Box::new(curr.0), - )) - }) - .unwrap_or(Predicate::true_()) - .0, - ), - )) - }), - // T and T - Box::new(|_| { - // Check if all the predicates in the context are true - if context.iter().all(|(b, _)| *b) { - // T and (X1 or X2 or ... or Xn) - Predicate(Expr::Binary( - Box::new(result.0.clone()), - ast::Operator::And, - Box::new( - context - .iter() - .map(|(_, p)| p.clone()) - .reduce(|accum, curr| { - Predicate(Expr::Binary( - Box::new(accum.0), - ast::Operator::And, - Box::new(curr.0), - )) - }) - .unwrap_or(Predicate::true_()) - .0, - ), - )) - } - // Check if there is at least one true predicate - else if context.iter().any(|(b, _)| *b) { - // T and (X1 or X2 or ... 
or Xn) - Predicate(Expr::Binary( - Box::new(result.0.clone()), - ast::Operator::And, - Box::new( - context - .iter() - .map(|(_, p)| p.clone()) - .reduce(|accum, curr| { - Predicate(Expr::Binary( - Box::new(accum.0), - ast::Operator::Or, - Box::new(curr.0), - )) - }) - .unwrap_or(Predicate::false_()) - .0, - ), - )) - // Predicate::And(vec![ - // result.clone(), - // Predicate::Or(context.iter().map(|(_, p)| p.clone()).collect()), - // ]) - } else { - // T and (X1 or X2 or ... or Xn or TRUE) - Predicate(Expr::Binary( - Box::new(result.0.clone()), - ast::Operator::And, - Box::new( - context - .iter() - .map(|(_, p)| p.clone()) - .chain(std::iter::once(Predicate::true_())) - .reduce(|accum, curr| { - Predicate(Expr::Binary( - Box::new(accum.0), - ast::Operator::Or, - Box::new(curr.0), - )) - }) - .unwrap() // Chain guarantees at least one value - .0, - ), - )) - } - }), - ], - rng, - ); - } - result - } -} - -#[cfg(test)] -mod tests { - use rand::{Rng as _, SeedableRng as _}; - use rand_chacha::ChaCha8Rng; - - use crate::{ - generation::{pick, predicate::SimplePredicate, Arbitrary, ArbitraryFrom as _}, - model::{ - query::predicate::{expr_to_value, Predicate}, - table::{SimValue, Table}, - }, - }; - - fn get_seed() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() - } - - #[test] - fn fuzz_arbitrary_table_true_simple_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - let row = pick(&values, &mut rng); - let predicate = SimplePredicate::arbitrary_from(&mut rng, (&table, row, true)).0; - let value = expr_to_value(&predicate.0, row, &table); - assert!( - value.as_ref().is_some_and(|value| value.as_bool()), 
- "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" - ) - } - } - - #[test] - fn fuzz_arbitrary_table_false_simple_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - let row = pick(&values, &mut rng); - let predicate = SimplePredicate::arbitrary_from(&mut rng, (&table, row, false)).0; - let value = expr_to_value(&predicate.0, row, &table); - assert!( - !value.as_ref().is_some_and(|value| value.as_bool()), - "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" - ) - } - } - - #[test] - fn fuzz_arbitrary_row_table_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - let row = pick(&values, &mut rng); - let predicate = Predicate::arbitrary_from(&mut rng, (&table, row)); - let value = expr_to_value(&predicate.0, row, &table); - assert!( - value.as_ref().is_some_and(|value| value.as_bool()), - "Predicate: {predicate:#?}\nValue: {value:#?}\nSeed: {seed}" - ) - } - } - - #[test] - fn fuzz_arbitrary_true_table_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - table.rows.extend(values.clone()); - let predicate = 
Predicate::arbitrary_from(&mut rng, (&table, true)); - let result = values - .iter() - .map(|row| predicate.test(row, &table)) - .reduce(|accum, curr| accum || curr) - .unwrap_or(false); - assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") - } - } - - #[test] - fn fuzz_arbitrary_false_table_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - table.rows.extend(values.clone()); - let predicate = Predicate::arbitrary_from(&mut rng, (&table, false)); - let result = values - .iter() - .map(|row| predicate.test(row, &table)) - .any(|res| !res); - assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") - } - } -} diff --git a/simulator/generation/predicate/unary.rs b/simulator/generation/predicate/unary.rs deleted file mode 100644 index f7f374b6e..000000000 --- a/simulator/generation/predicate/unary.rs +++ /dev/null @@ -1,306 +0,0 @@ -//! Contains code regarding generation for [ast::Expr::Unary] Predicate -//! TODO: for now just generating [ast::Literal], but want to also generate Columns and any -//! 
arbitrary [ast::Expr] - -use turso_sqlite3_parser::ast::{self, Expr}; - -use crate::{ - generation::{backtrack, pick, predicate::SimplePredicate, ArbitraryFromMaybe}, - model::{ - query::predicate::Predicate, - table::{SimValue, TableContext}, - }, -}; - -pub struct TrueValue(pub SimValue); - -impl ArbitraryFromMaybe<&SimValue> for TrueValue { - fn arbitrary_from_maybe(_rng: &mut R, value: &SimValue) -> Option - where - Self: Sized, - { - // If the Value is a true value return it else you cannot return a true Value - value.as_bool().then_some(Self(value.clone())) - } -} - -impl ArbitraryFromMaybe<&Vec<&SimValue>> for TrueValue { - fn arbitrary_from_maybe(rng: &mut R, values: &Vec<&SimValue>) -> Option - where - Self: Sized, - { - if values.is_empty() { - return Some(Self(SimValue::TRUE)); - } - - let value = pick(values, rng); - Self::arbitrary_from_maybe(rng, *value) - } -} - -pub struct FalseValue(pub SimValue); - -impl ArbitraryFromMaybe<&SimValue> for FalseValue { - fn arbitrary_from_maybe(_rng: &mut R, value: &SimValue) -> Option - where - Self: Sized, - { - // If the Value is a false value return it else you cannot return a false Value - (!value.as_bool()).then_some(Self(value.clone())) - } -} - -impl ArbitraryFromMaybe<&Vec<&SimValue>> for FalseValue { - fn arbitrary_from_maybe(rng: &mut R, values: &Vec<&SimValue>) -> Option - where - Self: Sized, - { - if values.is_empty() { - return Some(Self(SimValue::FALSE)); - } - - let value = pick(values, rng); - Self::arbitrary_from_maybe(rng, *value) - } -} - -pub struct BitNotValue(pub SimValue); - -impl ArbitraryFromMaybe<(&SimValue, bool)> for BitNotValue { - fn arbitrary_from_maybe( - _rng: &mut R, - (value, predicate): (&SimValue, bool), - ) -> Option - where - Self: Sized, - { - let bit_not_val = value.unary_exec(ast::UnaryOperator::BitwiseNot); - // If you bit not the Value and it meets the predicate return Some, else None - (bit_not_val.as_bool() == predicate).then_some(BitNotValue(value.clone())) - } -} - 
-impl ArbitraryFromMaybe<(&Vec<&SimValue>, bool)> for BitNotValue { - fn arbitrary_from_maybe( - rng: &mut R, - (values, predicate): (&Vec<&SimValue>, bool), - ) -> Option - where - Self: Sized, - { - if values.is_empty() { - return None; - } - - let value = pick(values, rng); - Self::arbitrary_from_maybe(rng, (*value, predicate)) - } -} - -// TODO: have some more complex generation with columns names here as well -impl SimplePredicate { - /// Generates a true [ast::Expr::Unary] [SimplePredicate] from a [TableContext] for some values in the table - pub fn true_unary( - rng: &mut R, - table: &T, - row: &[SimValue], - ) -> Self { - let columns = table.columns().collect::>(); - // Pick a random column - let column_index = rng.gen_range(0..columns.len()); - let column_value = &row[column_index]; - let num_retries = row.len(); - // Avoid creation of NULLs - if row.is_empty() { - return SimplePredicate(Predicate(Expr::Literal(SimValue::TRUE.into()))); - } - let expr = backtrack( - vec![ - ( - num_retries, - Box::new(|rng| { - TrueValue::arbitrary_from_maybe(rng, column_value).map(|value| { - assert!(value.0.as_bool()); - // Positive is a no-op in Sqlite - Expr::unary(ast::UnaryOperator::Positive, Expr::Literal(value.0.into())) - }) - }), - ), - // ( - // num_retries, - // Box::new(|rng| { - // TrueValue::arbitrary_from_maybe(rng, column_value).map(|value| { - // assert!(value.0.as_bool()); - // // True Value with negative is still True - // Expr::unary(ast::UnaryOperator::Negative, Expr::Literal(value.0.into())) - // }) - // }), - // ), - // ( - // num_retries, - // Box::new(|rng| { - // BitNotValue::arbitrary_from_maybe(rng, (column_value, true)).map(|value| { - // Expr::unary( - // ast::UnaryOperator::BitwiseNot, - // Expr::Literal(value.0.into()), - // ) - // }) - // }), - // ), - ( - num_retries, - Box::new(|rng| { - FalseValue::arbitrary_from_maybe(rng, column_value).map(|value| { - assert!(!value.0.as_bool()); - Expr::unary(ast::UnaryOperator::Not, 
Expr::Literal(value.0.into())) - }) - }), - ), - ], - rng, - ); - // If cannot generate a value - SimplePredicate(Predicate( - expr.unwrap_or(Expr::Literal(SimValue::TRUE.into())), - )) - } - - /// Generates a false [ast::Expr::Unary] [SimplePredicate] from a [TableContext] for a row in the table - pub fn false_unary( - rng: &mut R, - table: &T, - row: &[SimValue], - ) -> Self { - let columns = table.columns().collect::>(); - // Pick a random column - let column_index = rng.gen_range(0..columns.len()); - let column_value = &row[column_index]; - let num_retries = row.len(); - // Avoid creation of NULLs - if row.is_empty() { - return SimplePredicate(Predicate(Expr::Literal(SimValue::FALSE.into()))); - } - let expr = backtrack( - vec![ - // ( - // num_retries, - // Box::new(|rng| { - // FalseValue::arbitrary_from_maybe(rng, column_value).map(|value| { - // assert!(!value.0.as_bool()); - // // Positive is a no-op in Sqlite - // Expr::unary(ast::UnaryOperator::Positive, Expr::Literal(value.0.into())) - // }) - // }), - // ), - // ( - // num_retries, - // Box::new(|rng| { - // FalseValue::arbitrary_from_maybe(rng, column_value).map(|value| { - // assert!(!value.0.as_bool()); - // // True Value with negative is still True - // Expr::unary(ast::UnaryOperator::Negative, Expr::Literal(value.0.into())) - // }) - // }), - // ), - // ( - // num_retries, - // Box::new(|rng| { - // BitNotValue::arbitrary_from_maybe(rng, (column_value, false)).map(|value| { - // Expr::unary( - // ast::UnaryOperator::BitwiseNot, - // Expr::Literal(value.0.into()), - // ) - // }) - // }), - // ), - ( - num_retries, - Box::new(|rng| { - TrueValue::arbitrary_from_maybe(rng, column_value).map(|value| { - assert!(value.0.as_bool()); - Expr::unary(ast::UnaryOperator::Not, Expr::Literal(value.0.into())) - }) - }), - ), - ], - rng, - ); - // If cannot generate a value - SimplePredicate(Predicate( - expr.unwrap_or(Expr::Literal(SimValue::FALSE.into())), - )) - } -} - -#[cfg(test)] -mod tests { - use 
rand::{Rng as _, SeedableRng as _}; - use rand_chacha::ChaCha8Rng; - - use crate::{ - generation::{pick, predicate::SimplePredicate, Arbitrary, ArbitraryFrom as _}, - model::table::{SimValue, Table}, - }; - - fn get_seed() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() - } - - #[test] - fn fuzz_true_unary_simple_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - table.rows.extend(values.clone()); - let row = pick(&table.rows, &mut rng); - let predicate = SimplePredicate::true_unary(&mut rng, &table, row); - let result = values - .iter() - .map(|row| predicate.0.test(row, &table)) - .reduce(|accum, curr| accum || curr) - .unwrap_or(false); - assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") - } - } - - #[test] - fn fuzz_false_unary_simple_predicate() { - let seed = get_seed(); - let mut rng = ChaCha8Rng::seed_from_u64(seed); - for _ in 0..10000 { - let mut table = Table::arbitrary(&mut rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(&mut rng, &c.column_type)) - .collect() - }) - .collect(); - table.rows.extend(values.clone()); - let row = pick(&table.rows, &mut rng); - let predicate = SimplePredicate::false_unary(&mut rng, &table, row); - let result = values - .iter() - .map(|row| predicate.0.test(row, &table)) - .any(|res| !res); - assert!(result, "Predicate: {predicate:#?}\nSeed: {seed}") - } - } -} diff --git a/simulator/generation/property.rs b/simulator/generation/property.rs index 4725aa384..288c4e75d 100644 --- a/simulator/generation/property.rs +++ 
b/simulator/generation/property.rs @@ -1,30 +1,23 @@ use serde::{Deserialize, Serialize}; -use turso_core::{types, LimboError}; -use turso_sqlite3_parser::ast::{self}; - -use crate::{ - generation::Shadow as _, +use sql_generation::{ + generation::{frequency, pick, pick_index, ArbitraryFrom}, model::{ query::{ predicate::Predicate, - select::{ - CompoundOperator, CompoundSelect, Distinctness, ResultColumn, SelectBody, - SelectInner, - }, + select::{CompoundOperator, CompoundSelect, ResultColumn, SelectBody, SelectInner}, transaction::{Begin, Commit, Rollback}, update::Update, - Create, Delete, Drop, Insert, Query, Select, + Create, Delete, Drop, Insert, Select, }, table::SimValue, }, - runner::env::SimulatorEnv, }; +use turso_core::{types, LimboError}; +use turso_parser::ast::{self, Distinctness}; -use super::{ - frequency, pick, pick_index, - plan::{Assertion, Interaction, InteractionStats, ResultSet}, - ArbitraryFrom, -}; +use crate::{generation::Shadow as _, model::Query, runner::env::SimulatorEnv}; + +use super::plan::{Assertion, Interaction, InteractionStats, ResultSet}; /// Properties are representations of executable specifications /// about the database behavior. 
@@ -1073,7 +1066,7 @@ fn property_insert_values_select( // Get a random table let table = pick(&env.tables, rng); // Generate rows to insert - let rows = (0..rng.gen_range(1..=5)) + let rows = (0..rng.random_range(1..=5)) .map(|_| Vec::::arbitrary_from(rng, table)) .collect::>(); @@ -1088,10 +1081,10 @@ fn property_insert_values_select( }; // Choose if we want queries to be executed in an interactive transaction - let interactive = if rng.gen_bool(0.5) { + let interactive = if rng.random_bool(0.5) { Some(InteractiveQueryInfo { - start_with_immediate: rng.gen_bool(0.5), - end_with_commit: rng.gen_bool(0.5), + start_with_immediate: rng.random_bool(0.5), + end_with_commit: rng.random_bool(0.5), }) } else { None @@ -1107,7 +1100,7 @@ fn property_insert_values_select( immediate: interactive.start_with_immediate, })); } - for _ in 0..rng.gen_range(0..3) { + for _ in 0..rng.random_range(0..3) { let query = Query::arbitrary_from(rng, (env, remaining)); match &query { Query::Delete(Delete { @@ -1198,7 +1191,7 @@ fn property_select_limit(rng: &mut R, env: &SimulatorEnv) -> Prope table.name.clone(), vec![ResultColumn::Star], Predicate::arbitrary_from(rng, table), - Some(rng.gen_range(1..=5)), + Some(rng.random_range(1..=5)), Distinctness::All, ); Property::SelectLimit { select } @@ -1221,7 +1214,7 @@ fn property_double_create_failure( // The interactions in the middle has the following constraints; // - [x] There will be no errors in the middle interactions.(best effort) // - [ ] Table `t` will not be renamed or dropped.(todo: add this constraint once ALTER or DROP is implemented) - for _ in 0..rng.gen_range(0..3) { + for _ in 0..rng.random_range(0..3) { let query = Query::arbitrary_from(rng, (env, remaining)); if let Query::Create(Create { table: t }) = &query { // There will be no errors in the middle interactions. @@ -1254,7 +1247,7 @@ fn property_delete_select( // - [x] There will be no errors in the middle interactions. 
(this constraint is impossible to check, so this is just best effort) // - [x] A row that holds for the predicate will not be inserted. // - [ ] The table `t` will not be renamed, dropped, or altered. (todo: add this constraint once ALTER or DROP is implemented) - for _ in 0..rng.gen_range(0..3) { + for _ in 0..rng.random_range(0..3) { let query = Query::arbitrary_from(rng, (env, remaining)); match &query { Query::Insert(Insert::Values { table: t, values }) => { @@ -1309,7 +1302,7 @@ fn property_drop_select( let mut queries = Vec::new(); // - [x] There will be no errors in the middle interactions. (this constraint is impossible to check, so this is just best effort) // - [-] The table `t` will not be created, no table will be renamed to `t`. (todo: update this constraint once ALTER is implemented) - for _ in 0..rng.gen_range(0..3) { + for _ in 0..rng.random_range(0..3) { let query = Query::arbitrary_from(rng, (env, remaining)); if let Query::Create(Create { table: t }) = &query { // - The table `t` will not be created diff --git a/simulator/generation/query.rs b/simulator/generation/query.rs index 1abd41b1b..bb1344c2a 100644 --- a/simulator/generation/query.rs +++ b/simulator/generation/query.rs @@ -1,330 +1,11 @@ -use crate::generation::{Arbitrary, ArbitraryFrom, ArbitrarySizedFrom, Shadow}; -use crate::model::query::predicate::Predicate; -use crate::model::query::select::{ - CompoundOperator, CompoundSelect, Distinctness, FromClause, OrderBy, ResultColumn, SelectBody, - SelectInner, -}; -use crate::model::query::update::Update; -use crate::model::query::{Create, Delete, Drop, Insert, Query, Select}; -use crate::model::table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext}; -use crate::SimulatorEnv; -use itertools::Itertools; +use crate::{model::Query, SimulatorEnv}; use rand::Rng; -use turso_sqlite3_parser::ast::{Expr, SortOrder}; +use sql_generation::{ + generation::{frequency, Arbitrary, ArbitraryFrom}, + model::query::{update::Update, Create, 
Delete, Insert, Select}, +}; use super::property::Remaining; -use super::{backtrack, frequency, pick}; - -impl Arbitrary for Create { - fn arbitrary(rng: &mut R) -> Self { - Create { - table: Table::arbitrary(rng), - } - } -} - -impl ArbitraryFrom<&Vec
> for FromClause { - fn arbitrary_from(rng: &mut R, tables: &Vec
) -> Self { - let num_joins = match rng.gen_range(0..=100) { - 0..=90 => 0, - 91..=97 => 1, - 98..=100 => 2, - _ => unreachable!(), - }; - - let mut tables = tables.clone(); - let mut table = pick(&tables, rng).clone(); - - tables.retain(|t| t.name != table.name); - - let name = table.name.clone(); - - let mut table_context = JoinTable { - tables: Vec::new(), - rows: Vec::new(), - }; - - let joins: Vec<_> = (0..num_joins) - .filter_map(|_| { - if tables.is_empty() { - return None; - } - let join_table = pick(&tables, rng).clone(); - let joined_table_name = join_table.name.clone(); - - tables.retain(|t| t.name != join_table.name); - table_context.rows = table_context - .rows - .iter() - .cartesian_product(join_table.rows.iter()) - .map(|(t_row, j_row)| { - let mut row = t_row.clone(); - row.extend(j_row.clone()); - row - }) - .collect(); - // TODO: inneficient. use a Deque to push_front? - table_context.tables.insert(0, join_table); - for row in &mut table.rows { - assert_eq!( - row.len(), - table.columns.len(), - "Row length does not match column length after join" - ); - } - - let predicate = Predicate::arbitrary_from(rng, &table); - Some(JoinedTable { - table: joined_table_name, - join_type: JoinType::Inner, - on: predicate, - }) - }) - .collect(); - FromClause { table: name, joins } - } -} - -impl ArbitraryFrom<&SimulatorEnv> for SelectInner { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let from = FromClause::arbitrary_from(rng, &env.tables); - let mut tables = env.tables.clone(); - // todo: this is a temporary hack because env is not separated from the tables - let join_table = from - .shadow(&mut tables) - .expect("Failed to shadow FromClause"); - let cuml_col_count = join_table.columns().count(); - - let order_by = 'order_by: { - if rng.gen_bool(0.3) { - let order_by_table_candidates = from - .joins - .iter() - .map(|j| j.table.clone()) - .chain(std::iter::once(from.table.clone())) - .collect::>(); - let order_by_col_count = - 
(rng.gen::() * rng.gen::() * (cuml_col_count as f64)) as usize; // skew towards 0 - if order_by_col_count == 0 { - break 'order_by None; - } - let mut col_names = std::collections::HashSet::new(); - let mut order_by_cols = Vec::new(); - while order_by_cols.len() < order_by_col_count { - let table = pick(&order_by_table_candidates, rng); - let table = tables.iter().find(|t| t.name == *table).unwrap(); - let col = pick(&table.columns, rng); - let col_name = format!("{}.{}", table.name, col.name); - if col_names.insert(col_name.clone()) { - order_by_cols.push(( - col_name, - if rng.gen_bool(0.5) { - SortOrder::Asc - } else { - SortOrder::Desc - }, - )); - } - } - Some(OrderBy { - columns: order_by_cols, - }) - } else { - None - } - }; - - SelectInner { - distinctness: if env.opts.experimental_indexes { - Distinctness::arbitrary(rng) - } else { - Distinctness::All - }, - columns: vec![ResultColumn::Star], - from: Some(from), - where_clause: Predicate::arbitrary_from(rng, &join_table), - order_by, - } - } -} - -impl ArbitrarySizedFrom<&SimulatorEnv> for SelectInner { - fn arbitrary_sized_from( - rng: &mut R, - env: &SimulatorEnv, - num_result_columns: usize, - ) -> Self { - let mut select_inner = SelectInner::arbitrary_from(rng, env); - let select_from = &select_inner.from.as_ref().unwrap(); - let table_names = select_from - .joins - .iter() - .map(|j| j.table.clone()) - .chain(std::iter::once(select_from.table.clone())) - .collect::>(); - - let flat_columns_names = table_names - .iter() - .flat_map(|t| { - env.tables - .iter() - .find(|table| table.name == *t) - .unwrap() - .columns - .iter() - .map(|c| format!("{}.{}", t.clone(), c.name)) - }) - .collect::>(); - let selected_columns = pick_unique(&flat_columns_names, num_result_columns, rng); - let mut columns = Vec::new(); - for column_name in selected_columns { - columns.push(ResultColumn::Column(column_name.clone())); - } - select_inner.columns = columns; - select_inner - } -} - -impl Arbitrary for Distinctness { - 
fn arbitrary(rng: &mut R) -> Self { - match rng.gen_range(0..=5) { - 0..4 => Distinctness::All, - _ => Distinctness::Distinct, - } - } -} -impl Arbitrary for CompoundOperator { - fn arbitrary(rng: &mut R) -> Self { - match rng.gen_range(0..=1) { - 0 => CompoundOperator::Union, - 1 => CompoundOperator::UnionAll, - _ => unreachable!(), - } - } -} - -/// SelectFree is a wrapper around Select that allows for arbitrary generation -/// of selects without requiring a specific environment, which is useful for generating -/// arbitrary expressions without referring to the tables. -pub(crate) struct SelectFree(pub(crate) Select); - -impl ArbitraryFrom<&SimulatorEnv> for SelectFree { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let expr = Predicate(Expr::arbitrary_sized_from(rng, env, 8)); - let select = Select::expr(expr); - Self(select) - } -} - -impl ArbitraryFrom<&SimulatorEnv> for Select { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - // Generate a number of selects based on the query size - // If experimental indexes are enabled, we can have selects with compounds - // Otherwise, we just have a single select with no compounds - let num_compound_selects = if env.opts.experimental_indexes { - match rng.gen_range(0..=100) { - 0..=95 => 0, - 96..=99 => 1, - 100 => 2, - _ => unreachable!(), - } - } else { - 0 - }; - - let min_column_count_across_tables = - env.tables.iter().map(|t| t.columns.len()).min().unwrap(); - - let num_result_columns = rng.gen_range(1..=min_column_count_across_tables); - - let mut first = SelectInner::arbitrary_sized_from(rng, env, num_result_columns); - - let mut rest: Vec = (0..num_compound_selects) - .map(|_| SelectInner::arbitrary_sized_from(rng, env, num_result_columns)) - .collect(); - - if !rest.is_empty() { - // ORDER BY is not supported in compound selects yet - first.order_by = None; - for s in &mut rest { - s.order_by = None; - } - } - - Self { - body: SelectBody { - select: Box::new(first), - 
compounds: rest - .into_iter() - .map(|s| CompoundSelect { - operator: CompoundOperator::arbitrary(rng), - select: Box::new(s), - }) - .collect(), - }, - limit: None, - } - } -} - -impl ArbitraryFrom<&SimulatorEnv> for Insert { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let gen_values = |rng: &mut R| { - let table = pick(&env.tables, rng); - let num_rows = rng.gen_range(1..10); - let values: Vec> = (0..num_rows) - .map(|_| { - table - .columns - .iter() - .map(|c| SimValue::arbitrary_from(rng, &c.column_type)) - .collect() - }) - .collect(); - Some(Insert::Values { - table: table.name.clone(), - values, - }) - }; - - let _gen_select = |rng: &mut R| { - // Find a non-empty table - let select_table = env.tables.iter().find(|t| !t.rows.is_empty())?; - let row = pick(&select_table.rows, rng); - let predicate = Predicate::arbitrary_from(rng, (select_table, row)); - // Pick another table to insert into - let select = Select::simple(select_table.name.clone(), predicate); - let table = pick(&env.tables, rng); - Some(Insert::Select { - table: table.name.clone(), - select: Box::new(select), - }) - }; - - // TODO: Add back gen_select when https://github.com/tursodatabase/turso/issues/2129 is fixed. 
- // Backtrack here cannot return None - backtrack(vec![(1, Box::new(gen_values))], rng).unwrap() - } -} - -impl ArbitraryFrom<&SimulatorEnv> for Delete { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let table = pick(&env.tables, rng); - Self { - table: table.name.clone(), - predicate: Predicate::arbitrary_from(rng, table), - } - } -} - -impl ArbitraryFrom<&SimulatorEnv> for Drop { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let table = pick(&env.tables, rng); - Self { - table: table.name.clone(), - } - } -} impl ArbitraryFrom<(&SimulatorEnv, &Remaining)> for Query { fn arbitrary_from(rng: &mut R, (env, remaining): (&SimulatorEnv, &Remaining)) -> Self { @@ -355,43 +36,3 @@ impl ArbitraryFrom<(&SimulatorEnv, &Remaining)> for Query { ) } } - -fn pick_unique( - items: &[T], - count: usize, - rng: &mut impl rand::Rng, -) -> Vec -where - ::Owned: PartialEq, -{ - let mut picked: Vec = Vec::new(); - while picked.len() < count { - let item = pick(items, rng); - if !picked.contains(&item.to_owned()) { - picked.push(item.to_owned()); - } - } - picked -} - -impl ArbitraryFrom<&SimulatorEnv> for Update { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - let table = pick(&env.tables, rng); - let num_cols = rng.gen_range(1..=table.columns.len()); - let columns = pick_unique(&table.columns, num_cols, rng); - let set_values: Vec<(String, SimValue)> = columns - .iter() - .map(|column| { - ( - column.name.clone(), - SimValue::arbitrary_from(rng, &column.column_type), - ) - }) - .collect(); - Update { - table: table.name.clone(), - set_values, - predicate: Predicate::arbitrary_from(rng, table), - } - } -} diff --git a/simulator/generation/table.rs b/simulator/generation/table.rs deleted file mode 100644 index 66f48b5ad..000000000 --- a/simulator/generation/table.rs +++ /dev/null @@ -1,258 +0,0 @@ -use std::collections::HashSet; - -use rand::Rng; -use turso_core::Value; - -use crate::generation::{gen_random_text, pick, 
readable_name_custom, Arbitrary, ArbitraryFrom}; -use crate::model::table::{Column, ColumnType, Name, SimValue, Table}; - -use super::ArbitraryFromMaybe; - -impl Arbitrary for Name { - fn arbitrary(rng: &mut R) -> Self { - let name = readable_name_custom("_", rng); - Name(name.replace("-", "_")) - } -} - -impl Arbitrary for Table { - fn arbitrary(rng: &mut R) -> Self { - let name = Name::arbitrary(rng).0; - let columns = loop { - let large_table = rng.gen_bool(0.1); - let column_size = if large_table { - rng.gen_range(64..125) // todo: make this higher (128+) - } else { - rng.gen_range(1..=10) - }; - let columns = (1..=column_size) - .map(|_| Column::arbitrary(rng)) - .collect::>(); - // TODO: see if there is a better way to detect duplicates here - let mut set = HashSet::with_capacity(columns.len()); - set.extend(columns.iter()); - // Has repeated column name inside so generate again - if set.len() != columns.len() { - continue; - } - break columns; - }; - - Table { - rows: Vec::new(), - name, - columns, - indexes: vec![], - } - } -} - -impl Arbitrary for Column { - fn arbitrary(rng: &mut R) -> Self { - let name = Name::arbitrary(rng).0; - let column_type = ColumnType::arbitrary(rng); - Self { - name, - column_type, - primary: false, - unique: false, - } - } -} - -impl Arbitrary for ColumnType { - fn arbitrary(rng: &mut R) -> Self { - pick(&[Self::Integer, Self::Float, Self::Text, Self::Blob], rng).to_owned() - } -} - -impl ArbitraryFrom<&Table> for Vec { - fn arbitrary_from(rng: &mut R, table: &Table) -> Self { - let mut row = Vec::new(); - for column in table.columns.iter() { - let value = SimValue::arbitrary_from(rng, &column.column_type); - row.push(value); - } - row - } -} - -impl ArbitraryFrom<&Vec<&SimValue>> for SimValue { - fn arbitrary_from(rng: &mut R, values: &Vec<&Self>) -> Self { - if values.is_empty() { - return Self(Value::Null); - } - - pick(values, rng).to_owned().clone() - } -} - -impl ArbitraryFrom<&ColumnType> for SimValue { - fn 
arbitrary_from(rng: &mut R, column_type: &ColumnType) -> Self { - let value = match column_type { - ColumnType::Integer => Value::Integer(rng.gen_range(i64::MIN..i64::MAX)), - ColumnType::Float => Value::Float(rng.gen_range(-1e10..1e10)), - ColumnType::Text => Value::build_text(gen_random_text(rng)), - ColumnType::Blob => Value::Blob(gen_random_text(rng).as_bytes().to_vec()), - }; - SimValue(value) - } -} - -pub(crate) struct LTValue(pub(crate) SimValue); - -impl ArbitraryFrom<&Vec<&SimValue>> for LTValue { - fn arbitrary_from(rng: &mut R, values: &Vec<&SimValue>) -> Self { - if values.is_empty() { - return Self(SimValue(Value::Null)); - } - - // Get value less than all values - let value = Value::exec_min(values.iter().map(|value| &value.0)); - Self::arbitrary_from(rng, &SimValue(value)) - } -} - -impl ArbitraryFrom<&SimValue> for LTValue { - fn arbitrary_from(rng: &mut R, value: &SimValue) -> Self { - let new_value = match &value.0 { - Value::Integer(i) => Value::Integer(rng.gen_range(i64::MIN..*i - 1)), - Value::Float(f) => Value::Float(f - rng.gen_range(0.0..1e10)), - value @ Value::Text(..) 
=> { - // Either shorten the string, or make at least one character smaller and mutate the rest - let mut t = value.to_string(); - if rng.gen_bool(0.01) { - t.pop(); - Value::build_text(t) - } else { - let mut t = t.chars().map(|c| c as u32).collect::>(); - let index = rng.gen_range(0..t.len()); - t[index] -= 1; - // Mutate the rest of the string - for val in t.iter_mut().skip(index + 1) { - *val = rng.gen_range('a' as u32..='z' as u32); - } - let t = t - .into_iter() - .map(|c| char::from_u32(c).unwrap_or('z')) - .collect::(); - Value::build_text(t) - } - } - Value::Blob(b) => { - // Either shorten the blob, or make at least one byte smaller and mutate the rest - let mut b = b.clone(); - if rng.gen_bool(0.01) { - b.pop(); - Value::Blob(b) - } else { - let index = rng.gen_range(0..b.len()); - b[index] -= 1; - // Mutate the rest of the blob - for val in b.iter_mut().skip(index + 1) { - *val = rng.gen_range(0..=255); - } - Value::Blob(b) - } - } - _ => unreachable!(), - }; - Self(SimValue(new_value)) - } -} - -pub(crate) struct GTValue(pub(crate) SimValue); - -impl ArbitraryFrom<&Vec<&SimValue>> for GTValue { - fn arbitrary_from(rng: &mut R, values: &Vec<&SimValue>) -> Self { - if values.is_empty() { - return Self(SimValue(Value::Null)); - } - // Get value greater than all values - let value = Value::exec_max(values.iter().map(|value| &value.0)); - - Self::arbitrary_from(rng, &SimValue(value)) - } -} - -impl ArbitraryFrom<&SimValue> for GTValue { - fn arbitrary_from(rng: &mut R, value: &SimValue) -> Self { - let new_value = match &value.0 { - Value::Integer(i) => Value::Integer(rng.gen_range(*i..i64::MAX)), - Value::Float(f) => Value::Float(rng.gen_range(*f..1e10)), - value @ Value::Text(..) 
=> { - // Either lengthen the string, or make at least one character smaller and mutate the rest - let mut t = value.to_string(); - if rng.gen_bool(0.01) { - t.push(rng.gen_range(0..=255) as u8 as char); - Value::build_text(t) - } else { - let mut t = t.chars().map(|c| c as u32).collect::>(); - let index = rng.gen_range(0..t.len()); - t[index] += 1; - // Mutate the rest of the string - for val in t.iter_mut().skip(index + 1) { - *val = rng.gen_range('a' as u32..='z' as u32); - } - let t = t - .into_iter() - .map(|c| char::from_u32(c).unwrap_or('a')) - .collect::(); - Value::build_text(t) - } - } - Value::Blob(b) => { - // Either lengthen the blob, or make at least one byte smaller and mutate the rest - let mut b = b.clone(); - if rng.gen_bool(0.01) { - b.push(rng.gen_range(0..=255)); - Value::Blob(b) - } else { - let index = rng.gen_range(0..b.len()); - b[index] += 1; - // Mutate the rest of the blob - for val in b.iter_mut().skip(index + 1) { - *val = rng.gen_range(0..=255); - } - Value::Blob(b) - } - } - _ => unreachable!(), - }; - Self(SimValue(new_value)) - } -} - -pub(crate) struct LikeValue(pub(crate) SimValue); - -impl ArbitraryFromMaybe<&SimValue> for LikeValue { - fn arbitrary_from_maybe(rng: &mut R, value: &SimValue) -> Option { - match &value.0 { - value @ Value::Text(..) 
=> { - let t = value.to_string(); - let mut t = t.chars().collect::>(); - // Remove a number of characters, either insert `_` for each character removed, or - // insert one `%` for the whole substring - let mut i = 0; - while i < t.len() { - if rng.gen_bool(0.1) { - t[i] = '_'; - } else if rng.gen_bool(0.05) { - t[i] = '%'; - // skip a list of characters - for _ in 0..rng.gen_range(0..=3.min(t.len() - i - 1)) { - t.remove(i + 1); - } - } - i += 1; - } - let index = rng.gen_range(0..t.len()); - t.insert(index, '%'); - Some(Self(SimValue(Value::build_text( - t.into_iter().collect::(), - )))) - } - _ => None, - } - } -} diff --git a/simulator/main.rs b/simulator/main.rs index d2a31f099..0aaaf8be7 100644 --- a/simulator/main.rs +++ b/simulator/main.rs @@ -2,7 +2,6 @@ use anyhow::anyhow; use clap::Parser; use generation::plan::{Interaction, InteractionPlan, InteractionPlanState}; -use generation::ArbitraryFrom; use notify::event::{DataChange, ModifyKind}; use notify::{EventKind, RecursiveMode, Watcher}; use rand::prelude::*; @@ -11,6 +10,7 @@ use runner::cli::{SimulatorCLI, SimulatorCommand}; use runner::env::SimulatorEnv; use runner::execution::{execute_plans, Execution, ExecutionHistory, ExecutionResult}; use runner::{differential, watch}; +use sql_generation::generation::ArbitraryFrom; use std::any::Any; use std::backtrace::Backtrace; use std::fs::OpenOptions; @@ -507,7 +507,7 @@ fn setup_simulation( (seed, env, plans) } else { let seed = cli_opts.seed.unwrap_or_else(|| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); rng.next_u64() }); tracing::info!("seed={}", seed); diff --git a/simulator/model/mod.rs b/simulator/model/mod.rs index e68355ee4..ce249baf5 100644 --- a/simulator/model/mod.rs +++ b/simulator/model/mod.rs @@ -1,4 +1,417 @@ -pub mod query; -pub mod table; +use std::{collections::HashSet, fmt::Display}; -pub(crate) const FAULT_ERROR_MSG: &str = "Injected fault"; +use anyhow::Context; +use itertools::Itertools; +use serde::{Deserialize, 
Serialize}; +use sql_generation::model::{ + query::{ + select::{CompoundOperator, FromClause, ResultColumn, SelectInner}, + transaction::{Begin, Commit, Rollback}, + update::Update, + Create, CreateIndex, Delete, Drop, EmptyContext, Insert, Select, + }, + table::{JoinTable, JoinType, SimValue, Table, TableContext}, +}; +use turso_parser::ast::{fmt::ToTokens, Distinctness}; + +use crate::{generation::Shadow, runner::env::SimulatorTables}; + +// This type represents the potential queries on the database. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Query { + Create(Create), + Select(Select), + Insert(Insert), + Delete(Delete), + Update(Update), + Drop(Drop), + CreateIndex(CreateIndex), + Begin(Begin), + Commit(Commit), + Rollback(Rollback), +} + +impl Query { + pub fn dependencies(&self) -> HashSet { + match self { + Query::Select(select) => select.dependencies(), + Query::Create(_) => HashSet::new(), + Query::Insert(Insert::Select { table, .. }) + | Query::Insert(Insert::Values { table, .. }) + | Query::Delete(Delete { table, .. }) + | Query::Update(Update { table, .. }) + | Query::Drop(Drop { table, .. }) => HashSet::from_iter([table.clone()]), + Query::CreateIndex(CreateIndex { table_name, .. }) => { + HashSet::from_iter([table_name.clone()]) + } + Query::Begin(_) | Query::Commit(_) | Query::Rollback(_) => HashSet::new(), + } + } + pub fn uses(&self) -> Vec { + match self { + Query::Create(Create { table }) => vec![table.name.clone()], + Query::Select(select) => select.dependencies().into_iter().collect(), + Query::Insert(Insert::Select { table, .. }) + | Query::Insert(Insert::Values { table, .. }) + | Query::Delete(Delete { table, .. }) + | Query::Update(Update { table, .. }) + | Query::Drop(Drop { table, .. }) => vec![table.clone()], + Query::CreateIndex(CreateIndex { table_name, .. }) => vec![table_name.clone()], + Query::Begin(..) | Query::Commit(..) | Query::Rollback(..) 
=> vec![], + } + } +} + +impl Display for Query { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Create(create) => write!(f, "{create}"), + Self::Select(select) => write!(f, "{select}"), + Self::Insert(insert) => write!(f, "{insert}"), + Self::Delete(delete) => write!(f, "{delete}"), + Self::Update(update) => write!(f, "{update}"), + Self::Drop(drop) => write!(f, "{drop}"), + Self::CreateIndex(create_index) => write!(f, "{create_index}"), + Self::Begin(begin) => write!(f, "{begin}"), + Self::Commit(commit) => write!(f, "{commit}"), + Self::Rollback(rollback) => write!(f, "{rollback}"), + } + } +} + +impl Shadow for Query { + type Result = anyhow::Result>>; + + fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { + match self { + Query::Create(create) => create.shadow(env), + Query::Insert(insert) => insert.shadow(env), + Query::Delete(delete) => delete.shadow(env), + Query::Select(select) => select.shadow(env), + Query::Update(update) => update.shadow(env), + Query::Drop(drop) => drop.shadow(env), + Query::CreateIndex(create_index) => Ok(create_index.shadow(env)), + Query::Begin(begin) => Ok(begin.shadow(env)), + Query::Commit(commit) => Ok(commit.shadow(env)), + Query::Rollback(rollback) => Ok(rollback.shadow(env)), + } + } +} + +impl Shadow for Create { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + if !tables.iter().any(|t| t.name == self.table.name) { + tables.push(self.table.clone()); + Ok(vec![]) + } else { + Err(anyhow::anyhow!( + "Table {} already exists. 
CREATE TABLE statement ignored.", + self.table.name + )) + } + } +} + +impl Shadow for CreateIndex { + type Result = Vec>; + fn shadow(&self, env: &mut SimulatorTables) -> Vec> { + env.tables + .iter_mut() + .find(|t| t.name == self.table_name) + .unwrap() + .indexes + .push(self.index_name.clone()); + vec![] + } +} + +impl Shadow for Delete { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + let table = tables.tables.iter_mut().find(|t| t.name == self.table); + + if let Some(table) = table { + // If the table exists, we can delete from it + let t2 = table.clone(); + table.rows.retain_mut(|r| !self.predicate.test(r, &t2)); + } else { + // If the table does not exist, we return an error + return Err(anyhow::anyhow!( + "Table {} does not exist. DELETE statement ignored.", + self.table + )); + } + + Ok(vec![]) + } +} + +impl Shadow for Drop { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + if !tables.iter().any(|t| t.name == self.table) { + // If the table does not exist, we return an error + return Err(anyhow::anyhow!( + "Table {} does not exist. DROP statement ignored.", + self.table + )); + } + + tables.tables.retain(|t| t.name != self.table); + + Ok(vec![]) + } +} + +impl Shadow for Insert { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + match self { + Insert::Values { table, values } => { + if let Some(t) = tables.tables.iter_mut().find(|t| &t.name == table) { + t.rows.extend(values.clone()); + } else { + return Err(anyhow::anyhow!( + "Table {} does not exist. INSERT statement ignored.", + table + )); + } + } + Insert::Select { table, select } => { + let rows = select.shadow(tables)?; + if let Some(t) = tables.tables.iter_mut().find(|t| &t.name == table) { + t.rows.extend(rows); + } else { + return Err(anyhow::anyhow!( + "Table {} does not exist. 
INSERT statement ignored.", + table + )); + } + } + } + + Ok(vec![]) + } +} + +impl Shadow for FromClause { + type Result = anyhow::Result; + fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { + let tables = &mut env.tables; + + let first_table = tables + .iter() + .find(|t| t.name == self.table) + .context("Table not found")?; + + let mut join_table = JoinTable { + tables: vec![first_table.clone()], + rows: Vec::new(), + }; + + for join in &self.joins { + let joined_table = tables + .iter() + .find(|t| t.name == join.table) + .context("Joined table not found")?; + + join_table.tables.push(joined_table.clone()); + + match join.join_type { + JoinType::Inner => { + // Implement inner join logic + let join_rows = joined_table + .rows + .iter() + .filter(|row| join.on.test(row, joined_table)) + .cloned() + .collect::>(); + // take a cartesian product of the rows + let all_row_pairs = join_table + .rows + .clone() + .into_iter() + .cartesian_product(join_rows.iter()); + + for (row1, row2) in all_row_pairs { + let row = row1.iter().chain(row2.iter()).cloned().collect::>(); + + let is_in = join.on.test(&row, &join_table); + + if is_in { + join_table.rows.push(row); + } + } + } + _ => todo!(), + } + } + Ok(join_table) + } +} + +impl Shadow for SelectInner { + type Result = anyhow::Result; + + fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { + if let Some(from) = &self.from { + let mut join_table = from.shadow(env)?; + let col_count = join_table.columns().count(); + for row in &mut join_table.rows { + assert_eq!( + row.len(), + col_count, + "Row length does not match column length after join" + ); + } + let join_clone = join_table.clone(); + + join_table + .rows + .retain(|row| self.where_clause.test(row, &join_clone)); + + if self.distinctness == Distinctness::Distinct { + join_table.rows.sort_unstable(); + join_table.rows.dedup(); + } + + Ok(join_table) + } else { + assert!(self + .columns + .iter() + .all(|col| matches!(col, 
ResultColumn::Expr(_)))); + + // If `WHERE` is false, just return an empty table + if !self.where_clause.test(&[], &Table::anonymous(vec![])) { + return Ok(JoinTable { + tables: Vec::new(), + rows: Vec::new(), + }); + } + + // Compute the results of the column expressions and make a row + let mut row = Vec::new(); + for col in &self.columns { + match col { + ResultColumn::Expr(expr) => { + let value = expr.eval(&[], &Table::anonymous(vec![])); + if let Some(value) = value { + row.push(value); + } else { + return Err(anyhow::anyhow!( + "Failed to evaluate expression in free select ({})", + expr.0.format_with_context(&EmptyContext {}).unwrap() + )); + } + } + _ => unreachable!("Only expressions are allowed in free selects"), + } + } + + Ok(JoinTable { + tables: Vec::new(), + rows: vec![row], + }) + } + } +} + +impl Shadow for Select { + type Result = anyhow::Result>>; + + fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { + let first_result = self.body.select.shadow(env)?; + + let mut rows = first_result.rows; + + for compound in self.body.compounds.iter() { + let compound_results = compound.select.shadow(env)?; + + match compound.operator { + CompoundOperator::Union => { + // Union means we need to combine the results, removing duplicates + let mut new_rows = compound_results.rows; + new_rows.extend(rows.clone()); + new_rows.sort_unstable(); + new_rows.dedup(); + rows = new_rows; + } + CompoundOperator::UnionAll => { + // Union all means we just concatenate the results + rows.extend(compound_results.rows.into_iter()); + } + } + } + + Ok(rows) + } +} + +impl Shadow for Begin { + type Result = Vec>; + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + tables.snapshot = Some(tables.tables.clone()); + vec![] + } +} + +impl Shadow for Commit { + type Result = Vec>; + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + tables.snapshot = None; + vec![] + } +} + +impl Shadow for Rollback { + type Result = Vec>; + fn shadow(&self, 
tables: &mut SimulatorTables) -> Self::Result { + if let Some(tables_) = tables.snapshot.take() { + tables.tables = tables_; + } + vec![] + } +} + +impl Shadow for Update { + type Result = anyhow::Result>>; + + fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { + let table = tables.tables.iter_mut().find(|t| t.name == self.table); + + let table = if let Some(table) = table { + table + } else { + return Err(anyhow::anyhow!( + "Table {} does not exist. UPDATE statement ignored.", + self.table + )); + }; + + let t2 = table.clone(); + for row in table + .rows + .iter_mut() + .filter(|r| self.predicate.test(r, &t2)) + { + for (column, set_value) in &self.set_values { + if let Some((idx, _)) = table + .columns + .iter() + .enumerate() + .find(|(_, c)| &c.name == column) + { + row[idx] = set_value.clone(); + } + } + } + + Ok(vec![]) + } +} diff --git a/simulator/model/query/create.rs b/simulator/model/query/create.rs deleted file mode 100644 index ab0cd9789..000000000 --- a/simulator/model/query/create.rs +++ /dev/null @@ -1,45 +0,0 @@ -use std::fmt::Display; - -use serde::{Deserialize, Serialize}; - -use crate::{ - generation::Shadow, - model::table::{SimValue, Table}, - runner::env::SimulatorTables, -}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Create { - pub(crate) table: Table, -} - -impl Shadow for Create { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - if !tables.iter().any(|t| t.name == self.table.name) { - tables.push(self.table.clone()); - Ok(vec![]) - } else { - Err(anyhow::anyhow!( - "Table {} already exists. 
CREATE TABLE statement ignored.", - self.table.name - )) - } - } -} - -impl Display for Create { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "CREATE TABLE {} (", self.table.name)?; - - for (i, column) in self.table.columns.iter().enumerate() { - if i != 0 { - write!(f, ",")?; - } - write!(f, "{} {}", column.name, column.column_type)?; - } - - write!(f, ")") - } -} diff --git a/simulator/model/query/create_index.rs b/simulator/model/query/create_index.rs deleted file mode 100644 index 276396d4e..000000000 --- a/simulator/model/query/create_index.rs +++ /dev/null @@ -1,106 +0,0 @@ -use crate::{ - generation::{gen_random_text, pick, pick_n_unique, ArbitraryFrom, Shadow}, - model::table::SimValue, - runner::env::{SimulatorEnv, SimulatorTables}, -}; -use rand::Rng; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub enum SortOrder { - Asc, - Desc, -} - -impl std::fmt::Display for SortOrder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - SortOrder::Asc => write!(f, "ASC"), - SortOrder::Desc => write!(f, "DESC"), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub(crate) struct CreateIndex { - pub(crate) index_name: String, - pub(crate) table_name: String, - pub(crate) columns: Vec<(String, SortOrder)>, -} - -impl Shadow for CreateIndex { - type Result = Vec>; - fn shadow(&self, env: &mut SimulatorTables) -> Vec> { - env.tables - .iter_mut() - .find(|t| t.name == self.table_name) - .unwrap() - .indexes - .push(self.index_name.clone()); - vec![] - } -} - -impl std::fmt::Display for CreateIndex { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "CREATE INDEX {} ON {} ({})", - self.index_name, - self.table_name, - self.columns - .iter() - .map(|(name, order)| format!("{name} {order}")) - .collect::>() - .join(", ") - ) - } -} - -impl 
ArbitraryFrom<&SimulatorEnv> for CreateIndex { - fn arbitrary_from(rng: &mut R, env: &SimulatorEnv) -> Self { - assert!( - !env.tables.is_empty(), - "Cannot create an index when no tables exist in the environment." - ); - - let table = pick(&env.tables, rng); - - if table.columns.is_empty() { - panic!( - "Cannot create an index on table '{}' as it has no columns.", - table.name - ); - } - - let num_columns_to_pick = rng.gen_range(1..=table.columns.len()); - let picked_column_indices = pick_n_unique(0..table.columns.len(), num_columns_to_pick, rng); - - let columns = picked_column_indices - .into_iter() - .map(|i| { - let column = &table.columns[i]; - ( - column.name.clone(), - if rng.gen_bool(0.5) { - SortOrder::Asc - } else { - SortOrder::Desc - }, - ) - }) - .collect::>(); - - let index_name = format!( - "idx_{}_{}", - table.name, - gen_random_text(rng).chars().take(8).collect::() - ); - - CreateIndex { - index_name, - table_name: table.name.clone(), - columns, - } - } -} diff --git a/simulator/model/query/delete.rs b/simulator/model/query/delete.rs deleted file mode 100644 index 265cdfe96..000000000 --- a/simulator/model/query/delete.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::fmt::Display; - -use serde::{Deserialize, Serialize}; - -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; - -use super::predicate::Predicate; - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) struct Delete { - pub(crate) table: String, - pub(crate) predicate: Predicate, -} - -impl Shadow for Delete { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - let table = tables.tables.iter_mut().find(|t| t.name == self.table); - - if let Some(table) = table { - // If the table exists, we can delete from it - let t2 = table.clone(); - table.rows.retain_mut(|r| !self.predicate.test(r, &t2)); - } else { - // If the table does not exist, we return an error - return Err(anyhow::anyhow!( - 
"Table {} does not exist. DELETE statement ignored.", - self.table - )); - } - - Ok(vec![]) - } -} - -impl Display for Delete { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DELETE FROM {} WHERE {}", self.table, self.predicate) - } -} diff --git a/simulator/model/query/drop.rs b/simulator/model/query/drop.rs deleted file mode 100644 index 2b4379ff9..000000000 --- a/simulator/model/query/drop.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::fmt::Display; - -use serde::{Deserialize, Serialize}; - -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) struct Drop { - pub(crate) table: String, -} - -impl Shadow for Drop { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - if !tables.iter().any(|t| t.name == self.table) { - // If the table does not exist, we return an error - return Err(anyhow::anyhow!( - "Table {} does not exist. DROP statement ignored.", - self.table - )); - } - - tables.tables.retain(|t| t.name != self.table); - - Ok(vec![]) - } -} - -impl Display for Drop { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DROP TABLE {}", self.table) - } -} diff --git a/simulator/model/query/insert.rs b/simulator/model/query/insert.rs deleted file mode 100644 index 3dc8659df..000000000 --- a/simulator/model/query/insert.rs +++ /dev/null @@ -1,87 +0,0 @@ -use std::fmt::Display; - -use serde::{Deserialize, Serialize}; - -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; - -use super::select::Select; - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) enum Insert { - Values { - table: String, - values: Vec>, - }, - Select { - table: String, - select: Box
to support resolving a value from another table -// This function attempts to convert an simpler easily computable expression into values -// TODO: In the future, we can try to expand this computation if we want to support harder properties that require us -// to already know more values before hand -pub fn expr_to_value( - expr: &ast::Expr, - row: &[SimValue], - table: &T, -) -> Option { - match expr { - ast::Expr::DoublyQualified(_, _, ast::Name::Ident(col_name)) - | ast::Expr::DoublyQualified(_, _, ast::Name::Quoted(col_name)) - | ast::Expr::Qualified(_, ast::Name::Ident(col_name)) - | ast::Expr::Qualified(_, ast::Name::Quoted(col_name)) - | ast::Expr::Id(ast::Name::Ident(col_name)) => { - let columns = table.columns().collect::>(); - assert_eq!(row.len(), columns.len()); - columns - .iter() - .zip(row.iter()) - .find(|(column, _)| column.column.name == *col_name) - .map(|(_, value)| value) - .cloned() - } - ast::Expr::Literal(literal) => Some(literal.into()), - ast::Expr::Binary(lhs, op, rhs) => { - let lhs = expr_to_value(lhs, row, table)?; - let rhs = expr_to_value(rhs, row, table)?; - Some(lhs.binary_compare(&rhs, *op)) - } - ast::Expr::Like { - lhs, - not, - op, - rhs, - escape: _, // TODO: support escape - } => { - let lhs = expr_to_value(lhs, row, table)?; - let rhs = expr_to_value(rhs, row, table)?; - let res = lhs.like_compare(&rhs, *op); - let value: SimValue = if *not { !res } else { res }.into(); - Some(value) - } - ast::Expr::Unary(op, expr) => { - let value = expr_to_value(expr, row, table)?; - Some(value.unary_exec(*op)) - } - ast::Expr::Parenthesized(exprs) => { - assert_eq!(exprs.len(), 1); - expr_to_value(&exprs[0], row, table) - } - _ => unreachable!("{:?}", expr), - } -} - -impl Display for Predicate { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.to_fmt(f) - } -} diff --git a/simulator/model/query/select.rs b/simulator/model/query/select.rs deleted file mode 100644 index 2dcc762a8..000000000 --- 
a/simulator/model/query/select.rs +++ /dev/null @@ -1,497 +0,0 @@ -use std::{collections::HashSet, fmt::Display}; - -use anyhow::Context; -pub use ast::Distinctness; -use itertools::Itertools; -use serde::{Deserialize, Serialize}; -use turso_sqlite3_parser::ast::{self, fmt::ToTokens, SortOrder}; - -use crate::{ - generation::Shadow, - model::{ - query::EmptyContext, - table::{JoinTable, JoinType, JoinedTable, SimValue, Table, TableContext}, - }, - runner::env::SimulatorTables, -}; - -use super::predicate::Predicate; - -/// `SELECT` or `RETURNING` result column -// https://sqlite.org/syntax/result-column.html -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum ResultColumn { - /// expression - Expr(Predicate), - /// `*` - Star, - /// column name - Column(String), -} - -impl Display for ResultColumn { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ResultColumn::Expr(expr) => write!(f, "({expr})"), - ResultColumn::Star => write!(f, "*"), - ResultColumn::Column(name) => write!(f, "{name}"), - } - } -} -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) struct Select { - pub(crate) body: SelectBody, - pub(crate) limit: Option, -} - -impl Select { - pub fn simple(table: String, where_clause: Predicate) -> Self { - Self::single( - table, - vec![ResultColumn::Star], - where_clause, - None, - Distinctness::All, - ) - } - - pub fn expr(expr: Predicate) -> Self { - Select { - body: SelectBody { - select: Box::new(SelectInner { - distinctness: Distinctness::All, - columns: vec![ResultColumn::Expr(expr)], - from: None, - where_clause: Predicate::true_(), - order_by: None, - }), - compounds: Vec::new(), - }, - limit: None, - } - } - - pub fn single( - table: String, - result_columns: Vec, - where_clause: Predicate, - limit: Option, - distinct: Distinctness, - ) -> Self { - Select { - body: SelectBody { - select: Box::new(SelectInner { - distinctness: distinct, - columns: result_columns, - from: 
Some(FromClause { - table, - joins: Vec::new(), - }), - where_clause, - order_by: None, - }), - compounds: Vec::new(), - }, - limit, - } - } - - pub fn compound(left: Select, right: Select, operator: CompoundOperator) -> Self { - let mut body = left.body; - body.compounds.push(CompoundSelect { - operator, - select: Box::new(right.body.select.as_ref().clone()), - }); - Select { - body, - limit: left.limit.or(right.limit), - } - } - - pub(crate) fn dependencies(&self) -> HashSet { - if self.body.select.from.is_none() { - return HashSet::new(); - } - let from = self.body.select.from.as_ref().unwrap(); - let mut tables = HashSet::new(); - tables.insert(from.table.clone()); - - tables.extend(from.dependencies()); - - for compound in &self.body.compounds { - tables.extend( - compound - .select - .from - .as_ref() - .map(|f| f.dependencies()) - .unwrap_or(vec![]) - .into_iter(), - ); - } - - tables - } -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct SelectBody { - /// first select - pub select: Box, - /// compounds - pub compounds: Vec, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct OrderBy { - pub columns: Vec<(String, SortOrder)>, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct SelectInner { - /// `DISTINCT` - pub distinctness: Distinctness, - /// columns - pub columns: Vec, - /// `FROM` clause - pub from: Option, - /// `WHERE` clause - pub where_clause: Predicate, - /// `ORDER BY` clause - pub order_by: Option, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum CompoundOperator { - /// `UNION` - Union, - /// `UNION ALL` - UnionAll, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct CompoundSelect { - /// operator - pub operator: CompoundOperator, - /// select - pub select: Box, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct FromClause { - /// table - pub table: String, - /// `JOIN`ed tables - pub joins: 
Vec, -} - -impl FromClause { - fn to_sql_ast(&self) -> ast::FromClause { - ast::FromClause { - select: Some(Box::new(ast::SelectTable::Table( - ast::QualifiedName::single(ast::Name::from_str(&self.table)), - None, - None, - ))), - joins: if self.joins.is_empty() { - None - } else { - Some( - self.joins - .iter() - .map(|join| ast::JoinedSelectTable { - operator: match join.join_type { - JoinType::Inner => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::INNER)) - } - JoinType::Left => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::LEFT)) - } - JoinType::Right => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::RIGHT)) - } - JoinType::Full => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::OUTER)) - } - JoinType::Cross => { - ast::JoinOperator::TypedJoin(Some(ast::JoinType::CROSS)) - } - }, - table: ast::SelectTable::Table( - ast::QualifiedName::single(ast::Name::from_str(&join.table)), - None, - None, - ), - constraint: Some(ast::JoinConstraint::On(join.on.0.clone())), - }) - .collect(), - ) - }, - op: None, // FIXME: this is a temporary fix, we should remove this field - } - } - - pub(crate) fn dependencies(&self) -> Vec { - let mut deps = vec![self.table.clone()]; - for join in &self.joins { - deps.push(join.table.clone()); - } - deps - } -} - -impl Shadow for FromClause { - type Result = anyhow::Result; - fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { - let tables = &mut env.tables; - - let first_table = tables - .iter() - .find(|t| t.name == self.table) - .context("Table not found")?; - - let mut join_table = JoinTable { - tables: vec![first_table.clone()], - rows: Vec::new(), - }; - - for join in &self.joins { - let joined_table = tables - .iter() - .find(|t| t.name == join.table) - .context("Joined table not found")?; - - join_table.tables.push(joined_table.clone()); - - match join.join_type { - JoinType::Inner => { - // Implement inner join logic - let join_rows = joined_table - .rows - .iter() - .filter(|row| 
join.on.test(row, joined_table)) - .cloned() - .collect::>(); - // take a cartesian product of the rows - let all_row_pairs = join_table - .rows - .clone() - .into_iter() - .cartesian_product(join_rows.iter()); - - for (row1, row2) in all_row_pairs { - let row = row1.iter().chain(row2.iter()).cloned().collect::>(); - - let is_in = join.on.test(&row, &join_table); - - if is_in { - join_table.rows.push(row); - } - } - } - _ => todo!(), - } - } - Ok(join_table) - } -} - -impl Shadow for SelectInner { - type Result = anyhow::Result; - - fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { - if let Some(from) = &self.from { - let mut join_table = from.shadow(env)?; - let col_count = join_table.columns().count(); - for row in &mut join_table.rows { - assert_eq!( - row.len(), - col_count, - "Row length does not match column length after join" - ); - } - let join_clone = join_table.clone(); - - join_table - .rows - .retain(|row| self.where_clause.test(row, &join_clone)); - - if self.distinctness == Distinctness::Distinct { - join_table.rows.sort_unstable(); - join_table.rows.dedup(); - } - - Ok(join_table) - } else { - assert!(self - .columns - .iter() - .all(|col| matches!(col, ResultColumn::Expr(_)))); - - // If `WHERE` is false, just return an empty table - if !self.where_clause.test(&[], &Table::anonymous(vec![])) { - return Ok(JoinTable { - tables: Vec::new(), - rows: Vec::new(), - }); - } - - // Compute the results of the column expressions and make a row - let mut row = Vec::new(); - for col in &self.columns { - match col { - ResultColumn::Expr(expr) => { - let value = expr.eval(&[], &Table::anonymous(vec![])); - if let Some(value) = value { - row.push(value); - } else { - return Err(anyhow::anyhow!( - "Failed to evaluate expression in free select ({})", - expr.0.format_with_context(&EmptyContext {}).unwrap() - )); - } - } - _ => unreachable!("Only expressions are allowed in free selects"), - } - } - - Ok(JoinTable { - tables: Vec::new(), - rows: vec![row], 
- }) - } - } -} - -impl Shadow for Select { - type Result = anyhow::Result>>; - - fn shadow(&self, env: &mut SimulatorTables) -> Self::Result { - let first_result = self.body.select.shadow(env)?; - - let mut rows = first_result.rows; - - for compound in self.body.compounds.iter() { - let compound_results = compound.select.shadow(env)?; - - match compound.operator { - CompoundOperator::Union => { - // Union means we need to combine the results, removing duplicates - let mut new_rows = compound_results.rows; - new_rows.extend(rows.clone()); - new_rows.sort_unstable(); - new_rows.dedup(); - rows = new_rows; - } - CompoundOperator::UnionAll => { - // Union all means we just concatenate the results - rows.extend(compound_results.rows.into_iter()); - } - } - } - - Ok(rows) - } -} - -impl Select { - pub fn to_sql_ast(&self) -> ast::Select { - ast::Select { - with: None, - body: ast::SelectBody { - select: Box::new(ast::OneSelect::Select(Box::new(ast::SelectInner { - distinctness: if self.body.select.distinctness == Distinctness::Distinct { - Some(ast::Distinctness::Distinct) - } else { - None - }, - columns: self - .body - .select - .columns - .iter() - .map(|col| match col { - ResultColumn::Expr(expr) => { - ast::ResultColumn::Expr(expr.0.clone(), None) - } - ResultColumn::Star => ast::ResultColumn::Star, - ResultColumn::Column(name) => ast::ResultColumn::Expr( - ast::Expr::Id(ast::Name::Ident(name.clone())), - None, - ), - }) - .collect(), - from: self.body.select.from.as_ref().map(|f| f.to_sql_ast()), - where_clause: Some(self.body.select.where_clause.0.clone()), - group_by: None, - window_clause: None, - }))), - compounds: Some( - self.body - .compounds - .iter() - .map(|compound| ast::CompoundSelect { - operator: match compound.operator { - CompoundOperator::Union => ast::CompoundOperator::Union, - CompoundOperator::UnionAll => ast::CompoundOperator::UnionAll, - }, - select: Box::new(ast::OneSelect::Select(Box::new(ast::SelectInner { - distinctness: 
Some(compound.select.distinctness), - columns: compound - .select - .columns - .iter() - .map(|col| match col { - ResultColumn::Expr(expr) => { - ast::ResultColumn::Expr(expr.0.clone(), None) - } - ResultColumn::Star => ast::ResultColumn::Star, - ResultColumn::Column(name) => ast::ResultColumn::Expr( - ast::Expr::Id(ast::Name::Ident(name.clone())), - None, - ), - }) - .collect(), - from: compound.select.from.as_ref().map(|f| f.to_sql_ast()), - where_clause: Some(compound.select.where_clause.0.clone()), - group_by: None, - window_clause: None, - }))), - }) - .collect(), - ), - }, - order_by: self.body.select.order_by.as_ref().map(|o| { - o.columns - .iter() - .map(|(name, order)| ast::SortedColumn { - expr: ast::Expr::Id(ast::Name::Ident(name.clone())), - order: match order { - SortOrder::Asc => Some(ast::SortOrder::Asc), - SortOrder::Desc => Some(ast::SortOrder::Desc), - }, - nulls: None, - }) - .collect() - }), - limit: self.limit.map(|l| { - Box::new(ast::Limit { - expr: ast::Expr::Literal(ast::Literal::Numeric(l.to_string())), - offset: None, - }) - }), - } - } -} -impl Display for Select { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.to_sql_ast().to_fmt_with_context(f, &EmptyContext {}) - } -} - -#[cfg(test)] -mod select_tests { - - #[test] - fn test_select_display() {} -} diff --git a/simulator/model/query/transaction.rs b/simulator/model/query/transaction.rs deleted file mode 100644 index a73fb076e..000000000 --- a/simulator/model/query/transaction.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::fmt::Display; - -use serde::{Deserialize, Serialize}; - -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Begin { - pub(crate) immediate: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Commit; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Rollback; - -impl Shadow for Begin { - 
type Result = Vec>; - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - tables.snapshot = Some(tables.tables.clone()); - vec![] - } -} - -impl Shadow for Commit { - type Result = Vec>; - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - tables.snapshot = None; - vec![] - } -} - -impl Shadow for Rollback { - type Result = Vec>; - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - if let Some(tables_) = tables.snapshot.take() { - tables.tables = tables_; - } - vec![] - } -} - -impl Display for Begin { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "BEGIN {}", if self.immediate { "IMMEDIATE" } else { "" }) - } -} - -impl Display for Commit { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "COMMIT") - } -} - -impl Display for Rollback { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "ROLLBACK") - } -} diff --git a/simulator/model/query/update.rs b/simulator/model/query/update.rs deleted file mode 100644 index a4cc13fa8..000000000 --- a/simulator/model/query/update.rs +++ /dev/null @@ -1,71 +0,0 @@ -use std::fmt::Display; - -use serde::{Deserialize, Serialize}; - -use crate::{generation::Shadow, model::table::SimValue, runner::env::SimulatorTables}; - -use super::predicate::Predicate; - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub(crate) struct Update { - pub(crate) table: String, - pub(crate) set_values: Vec<(String, SimValue)>, // Pair of value for set expressions => SET name=value - pub(crate) predicate: Predicate, -} - -impl Update { - pub fn table(&self) -> &str { - &self.table - } -} - -impl Shadow for Update { - type Result = anyhow::Result>>; - - fn shadow(&self, tables: &mut SimulatorTables) -> Self::Result { - let table = tables.tables.iter_mut().find(|t| t.name == self.table); - - let table = if let Some(table) = table { - table - } else { - return Err(anyhow::anyhow!( - "Table {} does 
not exist. UPDATE statement ignored.", - self.table - )); - }; - - let t2 = table.clone(); - for row in table - .rows - .iter_mut() - .filter(|r| self.predicate.test(r, &t2)) - { - for (column, set_value) in &self.set_values { - if let Some((idx, _)) = table - .columns - .iter() - .enumerate() - .find(|(_, c)| &c.name == column) - { - row[idx] = set_value.clone(); - } - } - } - - Ok(vec![]) - } -} - -impl Display for Update { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "UPDATE {} SET ", self.table)?; - for (i, (name, value)) in self.set_values.iter().enumerate() { - if i != 0 { - write!(f, ", ")?; - } - write!(f, "{name} = {value}")?; - } - write!(f, " WHERE {}", self.predicate)?; - Ok(()) - } -} diff --git a/simulator/model/table.rs b/simulator/model/table.rs deleted file mode 100644 index b69a197d0..000000000 --- a/simulator/model/table.rs +++ /dev/null @@ -1,428 +0,0 @@ -use std::{fmt::Display, hash::Hash, ops::Deref}; - -use serde::{Deserialize, Serialize}; -use turso_core::{numeric::Numeric, types}; -use turso_sqlite3_parser::ast; - -use crate::model::query::predicate::Predicate; - -pub(crate) struct Name(pub(crate) String); - -impl Deref for Name { - type Target = str; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[derive(Debug, Clone, Copy)] -pub struct ContextColumn<'a> { - pub table_name: &'a str, - pub column: &'a Column, -} - -pub trait TableContext { - fn columns<'a>(&'a self) -> impl Iterator>; - fn rows(&self) -> &Vec>; -} - -impl TableContext for Table { - fn columns<'a>(&'a self) -> impl Iterator> { - self.columns.iter().map(|col| ContextColumn { - column: col, - table_name: &self.name, - }) - } - - fn rows(&self) -> &Vec> { - &self.rows - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Table { - pub(crate) name: String, - pub(crate) columns: Vec, - pub(crate) rows: Vec>, - pub(crate) indexes: Vec, -} - -impl Table { - pub fn anonymous(rows: Vec>) -> Self { - Self { - 
rows, - name: "".to_string(), - columns: vec![], - indexes: vec![], - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Column { - pub(crate) name: String, - pub(crate) column_type: ColumnType, - pub(crate) primary: bool, - pub(crate) unique: bool, -} - -// Uniquely defined by name in this case -impl Hash for Column { - fn hash(&self, state: &mut H) { - self.name.hash(state); - } -} - -impl PartialEq for Column { - fn eq(&self, other: &Self) -> bool { - self.name == other.name - } -} - -impl Eq for Column {} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) enum ColumnType { - Integer, - Float, - Text, - Blob, -} - -impl Display for ColumnType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Integer => write!(f, "INTEGER"), - Self::Float => write!(f, "REAL"), - Self::Text => write!(f, "TEXT"), - Self::Blob => write!(f, "BLOB"), - } - } -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct JoinedTable { - /// table name - pub table: String, - /// `JOIN` type - pub join_type: JoinType, - /// `ON` clause - pub on: Predicate, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum JoinType { - Inner, - Left, - Right, - Full, - Cross, -} - -impl TableContext for JoinTable { - fn columns<'a>(&'a self) -> impl Iterator> { - self.tables.iter().flat_map(|table| table.columns()) - } - - fn rows(&self) -> &Vec> { - &self.rows - } -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct JoinTable { - pub tables: Vec
, - pub rows: Vec>, -} - -fn float_to_string(float: &f64, serializer: S) -> Result -where - S: serde::Serializer, -{ - serializer.serialize_str(&format!("{float}")) -} - -fn string_to_float<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - s.parse().map_err(serde::de::Error::custom) -} - -#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Serialize, Deserialize)] -pub(crate) struct SimValue(pub turso_core::Value); - -fn to_sqlite_blob(bytes: &[u8]) -> String { - format!( - "X'{}'", - bytes - .iter() - .fold(String::new(), |acc, b| acc + &format!("{b:02X}")) - ) -} - -impl Display for SimValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.0 { - types::Value::Null => write!(f, "NULL"), - types::Value::Integer(i) => write!(f, "{i}"), - types::Value::Float(fl) => write!(f, "{fl}"), - value @ types::Value::Text(..) => write!(f, "'{value}'"), - types::Value::Blob(b) => write!(f, "{}", to_sqlite_blob(b)), - } - } -} - -impl SimValue { - pub const FALSE: Self = SimValue(types::Value::Integer(0)); - pub const TRUE: Self = SimValue(types::Value::Integer(1)); - - pub fn as_bool(&self) -> bool { - Numeric::from(&self.0).try_into_bool().unwrap_or_default() - } - - // TODO: support more predicates - /// Returns a Result of a Binary Operation - /// - /// TODO: forget collations for now - /// TODO: have the [ast::Operator::Equals], [ast::Operator::NotEquals], [ast::Operator::Greater], - /// [ast::Operator::GreaterEquals], [ast::Operator::Less], [ast::Operator::LessEquals] function to be extracted - /// into its functions in turso_core so that it can be used here - pub fn binary_compare(&self, other: &Self, operator: ast::Operator) -> SimValue { - match operator { - ast::Operator::Add => self.0.exec_add(&other.0).into(), - ast::Operator::And => self.0.exec_and(&other.0).into(), - ast::Operator::ArrowRight => todo!(), - ast::Operator::ArrowRightShift => todo!(), - 
ast::Operator::BitwiseAnd => self.0.exec_bit_and(&other.0).into(), - ast::Operator::BitwiseOr => self.0.exec_bit_or(&other.0).into(), - ast::Operator::BitwiseNot => todo!(), // TODO: Do not see any function usage of this operator in Core - ast::Operator::Concat => self.0.exec_concat(&other.0).into(), - ast::Operator::Equals => (self == other).into(), - ast::Operator::Divide => self.0.exec_divide(&other.0).into(), - ast::Operator::Greater => (self > other).into(), - ast::Operator::GreaterEquals => (self >= other).into(), - // TODO: Test these implementations - ast::Operator::Is => match (&self.0, &other.0) { - (types::Value::Null, types::Value::Null) => true.into(), - (types::Value::Null, _) => false.into(), - (_, types::Value::Null) => false.into(), - _ => self.binary_compare(other, ast::Operator::Equals), - }, - ast::Operator::IsNot => self - .binary_compare(other, ast::Operator::Is) - .unary_exec(ast::UnaryOperator::Not), - ast::Operator::LeftShift => self.0.exec_shift_left(&other.0).into(), - ast::Operator::Less => (self < other).into(), - ast::Operator::LessEquals => (self <= other).into(), - ast::Operator::Modulus => self.0.exec_remainder(&other.0).into(), - ast::Operator::Multiply => self.0.exec_multiply(&other.0).into(), - ast::Operator::NotEquals => (self != other).into(), - ast::Operator::Or => self.0.exec_or(&other.0).into(), - ast::Operator::RightShift => self.0.exec_shift_right(&other.0).into(), - ast::Operator::Subtract => self.0.exec_subtract(&other.0).into(), - } - } - - // TODO: support more operators. 
Copy the implementation for exec_glob - pub fn like_compare(&self, other: &Self, operator: ast::LikeOperator) -> bool { - match operator { - ast::LikeOperator::Glob => todo!(), - ast::LikeOperator::Like => { - // TODO: support ESCAPE `expr` option in AST - // TODO: regex cache - types::Value::exec_like( - None, - other.0.to_string().as_str(), - self.0.to_string().as_str(), - ) - } - ast::LikeOperator::Match => todo!(), - ast::LikeOperator::Regexp => todo!(), - } - } - - pub fn unary_exec(&self, operator: ast::UnaryOperator) -> SimValue { - let new_value = match operator { - ast::UnaryOperator::BitwiseNot => self.0.exec_bit_not(), - ast::UnaryOperator::Negative => { - SimValue(types::Value::Integer(0)) - .binary_compare(self, ast::Operator::Subtract) - .0 - } - ast::UnaryOperator::Not => self.0.exec_boolean_not(), - ast::UnaryOperator::Positive => self.0.clone(), - }; - Self(new_value) - } -} - -impl From for SimValue { - fn from(value: ast::Literal) -> Self { - Self::from(&value) - } -} - -/// Converts a SQL string literal with already-escaped single quotes to a regular string by: -/// - Removing the enclosing single quotes -/// - Converting sequences of 2N single quotes ('''''') to N single quotes (''') -/// -/// Assumes: -/// - The input starts and ends with a single quote -/// - The input contains a valid amount of single quotes inside the enclosing quotes; -/// i.e. 
any ' is escaped as a double '' -fn unescape_singlequotes(input: &str) -> String { - assert!( - input.starts_with('\'') && input.ends_with('\''), - "Input string must be wrapped in single quotes" - ); - // Skip first and last characters (the enclosing quotes) - let inner = &input[1..input.len() - 1]; - - let mut result = String::with_capacity(inner.len()); - let mut chars = inner.chars().peekable(); - - while let Some(c) = chars.next() { - if c == '\'' { - // Count consecutive single quotes - let mut quote_count = 1; - while chars.peek() == Some(&'\'') { - quote_count += 1; - chars.next(); - } - assert!( - quote_count % 2 == 0, - "Expected even number of quotes, got {quote_count} in string {input}" - ); - // For every pair of quotes, output one quote - for _ in 0..(quote_count / 2) { - result.push('\''); - } - } else { - result.push(c); - } - } - - result -} - -/// Escapes a string by doubling contained single quotes and then wrapping it in single quotes. -fn escape_singlequotes(input: &str) -> String { - let mut result = String::with_capacity(input.len() + 2); - result.push('\''); - result.push_str(&input.replace("'", "''")); - result.push('\''); - result -} - -impl From<&ast::Literal> for SimValue { - fn from(value: &ast::Literal) -> Self { - let new_value = match value { - ast::Literal::Null => types::Value::Null, - ast::Literal::Numeric(number) => Numeric::from(number).into(), - ast::Literal::String(string) => types::Value::build_text(unescape_singlequotes(string)), - ast::Literal::Blob(blob) => types::Value::Blob( - blob.as_bytes() - .chunks_exact(2) - .map(|pair| { - // We assume that sqlite3-parser has already validated that - // the input is valid hex string, thus unwrap is safe. 
- let hex_byte = std::str::from_utf8(pair).unwrap(); - u8::from_str_radix(hex_byte, 16).unwrap() - }) - .collect(), - ), - ast::Literal::Keyword(keyword) => match keyword.to_uppercase().as_str() { - "TRUE" => types::Value::Integer(1), - "FALSE" => types::Value::Integer(0), - "NULL" => types::Value::Null, - _ => unimplemented!("Unsupported keyword literal: {}", keyword), - }, - lit => unimplemented!("{:?}", lit), - }; - Self(new_value) - } -} - -impl From for ast::Literal { - fn from(value: SimValue) -> Self { - Self::from(&value) - } -} - -impl From<&SimValue> for ast::Literal { - fn from(value: &SimValue) -> Self { - match &value.0 { - types::Value::Null => Self::Null, - types::Value::Integer(i) => Self::Numeric(i.to_string()), - types::Value::Float(f) => Self::Numeric(f.to_string()), - text @ types::Value::Text(..) => Self::String(escape_singlequotes(&text.to_string())), - types::Value::Blob(blob) => Self::Blob(hex::encode(blob)), - } - } -} - -impl From for SimValue { - fn from(value: bool) -> Self { - if value { - SimValue::TRUE - } else { - SimValue::FALSE - } - } -} - -impl From for turso_core::types::Value { - fn from(value: SimValue) -> Self { - value.0 - } -} - -impl From<&SimValue> for turso_core::types::Value { - fn from(value: &SimValue) -> Self { - value.0.clone() - } -} - -impl From for SimValue { - fn from(value: turso_core::types::Value) -> Self { - Self(value) - } -} - -impl From<&turso_core::types::Value> for SimValue { - fn from(value: &turso_core::types::Value) -> Self { - Self(value.clone()) - } -} - -#[cfg(test)] -mod tests { - use crate::model::table::{escape_singlequotes, unescape_singlequotes}; - - #[test] - fn test_unescape_singlequotes() { - assert_eq!(unescape_singlequotes("'hello'"), "hello"); - assert_eq!(unescape_singlequotes("'O''Reilly'"), "O'Reilly"); - assert_eq!( - unescape_singlequotes("'multiple''single''quotes'"), - "multiple'single'quotes" - ); - assert_eq!(unescape_singlequotes("'test''''test'"), "test''test"); - 
assert_eq!(unescape_singlequotes("'many''''''quotes'"), "many'''quotes"); - } - - #[test] - fn test_escape_singlequotes() { - assert_eq!(escape_singlequotes("hello"), "'hello'"); - assert_eq!(escape_singlequotes("O'Reilly"), "'O''Reilly'"); - assert_eq!( - escape_singlequotes("multiple'single'quotes"), - "'multiple''single''quotes'" - ); - assert_eq!(escape_singlequotes("test''test"), "'test''''test'"); - assert_eq!(escape_singlequotes("many'''quotes"), "'many''''''quotes'"); - } -} diff --git a/simulator/runner/clock.rs b/simulator/runner/clock.rs index ef687c5c1..871a01346 100644 --- a/simulator/runner/clock.rs +++ b/simulator/runner/clock.rs @@ -27,7 +27,7 @@ impl SimulatorClock { let nanos = self .rng .borrow_mut() - .gen_range(self.min_tick..self.max_tick); + .random_range(self.min_tick..self.max_tick); let nanos = std::time::Duration::from_micros(nanos); *time += nanos; *time diff --git a/simulator/runner/differential.rs b/simulator/runner/differential.rs index 7d37babe7..5723418c1 100644 --- a/simulator/runner/differential.rs +++ b/simulator/runner/differential.rs @@ -1,14 +1,14 @@ use std::sync::{Arc, Mutex}; +use sql_generation::{generation::pick_index, model::table::SimValue}; use turso_core::Value; use crate::{ generation::{ - pick_index, plan::{Interaction, InteractionPlanState, ResultSet}, Shadow as _, }, - model::{query::Query, table::SimValue}, + model::Query, runner::execution::ExecutionContinuation, InteractionPlan, }; diff --git a/simulator/runner/doublecheck.rs b/simulator/runner/doublecheck.rs index 5ba98ca50..7c9d33b4e 100644 --- a/simulator/runner/doublecheck.rs +++ b/simulator/runner/doublecheck.rs @@ -3,9 +3,10 @@ use std::{ sync::{Arc, Mutex}, }; +use sql_generation::generation::pick_index; + use crate::{ - generation::{pick_index, plan::InteractionPlanState}, - runner::execution::ExecutionContinuation, + generation::plan::InteractionPlanState, runner::execution::ExecutionContinuation, InteractionPlan, }; diff --git 
a/simulator/runner/env.rs b/simulator/runner/env.rs index f5787bc57..a29adc591 100644 --- a/simulator/runner/env.rs +++ b/simulator/runner/env.rs @@ -7,10 +7,9 @@ use std::sync::Arc; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; +use sql_generation::model::table::Table; use turso_core::Database; -use crate::model::table::Table; - use crate::runner::io::SimulatorIO; use super::cli::SimulatorCLI; @@ -173,29 +172,29 @@ impl SimulatorEnv { let mut delete_percent = 0.0; let mut update_percent = 0.0; - let read_percent = rng.gen_range(0.0..=total); + let read_percent = rng.random_range(0.0..=total); let write_percent = total - read_percent; if !cli_opts.disable_create { // Create percent should be 5-15% of the write percent - create_percent = rng.gen_range(0.05..=0.15) * write_percent; + create_percent = rng.random_range(0.05..=0.15) * write_percent; } if !cli_opts.disable_create_index { // Create indexpercent should be 2-5% of the write percent - create_index_percent = rng.gen_range(0.02..=0.05) * write_percent; + create_index_percent = rng.random_range(0.02..=0.05) * write_percent; } if !cli_opts.disable_drop { // Drop percent should be 2-5% of the write percent - drop_percent = rng.gen_range(0.02..=0.05) * write_percent; + drop_percent = rng.random_range(0.02..=0.05) * write_percent; } if !cli_opts.disable_delete { // Delete percent should be 10-20% of the write percent - delete_percent = rng.gen_range(0.1..=0.2) * write_percent; + delete_percent = rng.random_range(0.1..=0.2) * write_percent; } if !cli_opts.disable_update { // Update percent should be 10-20% of the write percent // TODO: freestyling the percentage - update_percent = rng.gen_range(0.1..=0.2) * write_percent; + update_percent = rng.random_range(0.1..=0.2) * write_percent; } let write_percent = write_percent @@ -220,10 +219,10 @@ impl SimulatorEnv { let opts = SimulatorOpts { seed, - ticks: rng.gen_range(cli_opts.minimum_tests..=cli_opts.maximum_tests), + ticks: 
rng.random_range(cli_opts.minimum_tests..=cli_opts.maximum_tests), max_connections: 1, // TODO: for now let's use one connection as we didn't implement // correct transactions processing - max_tables: rng.gen_range(0..128), + max_tables: rng.random_range(0..128), create_percent, create_index_percent, read_percent, @@ -243,7 +242,7 @@ impl SimulatorEnv { disable_fsync_no_wait: cli_opts.disable_fsync_no_wait, disable_faulty_query: cli_opts.disable_faulty_query, page_size: 4096, // TODO: randomize this too - max_interactions: rng.gen_range(cli_opts.minimum_tests..=cli_opts.maximum_tests), + max_interactions: rng.random_range(cli_opts.minimum_tests..=cli_opts.maximum_tests), max_time_simulation: cli_opts.maximum_time, disable_reopen_database: cli_opts.disable_reopen_database, latency_probability: cli_opts.latency_probability, diff --git a/simulator/runner/execution.rs b/simulator/runner/execution.rs index 9cbac3826..fa3dcbff9 100644 --- a/simulator/runner/execution.rs +++ b/simulator/runner/execution.rs @@ -1,10 +1,10 @@ use std::sync::{Arc, Mutex}; +use sql_generation::generation::pick_index; use tracing::instrument; use turso_core::{Connection, LimboError, Result, StepResult}; use crate::generation::{ - pick_index, plan::{Interaction, InteractionPlan, InteractionPlanState, ResultSet}, Shadow as _, }; diff --git a/simulator/runner/file.rs b/simulator/runner/file.rs index c8c5ff4fa..bbda05b1d 100644 --- a/simulator/runner/file.rs +++ b/simulator/runner/file.rs @@ -9,7 +9,7 @@ use rand_chacha::ChaCha8Rng; use tracing::{instrument, Level}; use turso_core::{File, Result}; -use crate::{model::FAULT_ERROR_MSG, runner::clock::SimulatorClock}; +use crate::runner::{clock::SimulatorClock, FAULT_ERROR_MSG}; pub(crate) struct SimulatorFile { pub path: String, pub(crate) inner: Arc, @@ -100,10 +100,10 @@ impl SimulatorFile { fn generate_latency_duration(&self) -> Option { let mut rng = self.rng.borrow_mut(); // Chance to introduce some latency - 
rng.gen_bool(self.latency_probability as f64 / 100.0) + rng.random_bool(self.latency_probability as f64 / 100.0) .then(|| { let now = self.clock.now(); - let sum = now + std::time::Duration::from_millis(rng.gen_range(5..20)); + let sum = now + std::time::Duration::from_millis(rng.random_range(5..20)); sum.into() }) } diff --git a/simulator/runner/mod.rs b/simulator/runner/mod.rs index b56335da5..3eef78331 100644 --- a/simulator/runner/mod.rs +++ b/simulator/runner/mod.rs @@ -9,3 +9,5 @@ pub mod execution; pub mod file; pub mod io; pub mod watch; + +pub const FAULT_ERROR_MSG: &str = "Injected Fault"; diff --git a/simulator/runner/watch.rs b/simulator/runner/watch.rs index 90e8edc68..feab80af1 100644 --- a/simulator/runner/watch.rs +++ b/simulator/runner/watch.rs @@ -1,10 +1,9 @@ use std::sync::{Arc, Mutex}; +use sql_generation::generation::pick_index; + use crate::{ - generation::{ - pick_index, - plan::{Interaction, InteractionPlanState}, - }, + generation::plan::{Interaction, InteractionPlanState}, runner::execution::ExecutionContinuation, }; diff --git a/simulator/shrink/plan.rs b/simulator/shrink/plan.rs index f08ccbb5a..bccd07afd 100644 --- a/simulator/shrink/plan.rs +++ b/simulator/shrink/plan.rs @@ -1,9 +1,9 @@ -use crate::model::query::Query; use crate::{ generation::{ plan::{Interaction, InteractionPlan, Interactions}, property::Property, }, + model::Query, run_simulation, runner::execution::Execution, SandboxedResult, SimulatorEnv, diff --git a/sql_generation/generation/mod.rs b/sql_generation/generation/mod.rs index d9b7e0cbc..25bd7ec09 100644 --- a/sql_generation/generation/mod.rs +++ b/sql_generation/generation/mod.rs @@ -10,15 +10,16 @@ pub mod predicate; pub mod query; pub mod table; +#[derive(Debug, Clone, Copy)] pub struct Opts { /// Indexes enabled - indexes: bool, + pub indexes: bool, } /// Trait used to provide context to generation functions pub trait GenerationContext { fn tables(&self) -> &Vec
; - fn opts(&self) -> &Opts; + fn opts(&self) -> Opts; } type ArbitraryFromFunc<'a, R, T> = Box T + 'a>; diff --git a/sql_generation/model/query/mod.rs b/sql_generation/model/query/mod.rs index 016ebec29..5bf0cecde 100644 --- a/sql_generation/model/query/mod.rs +++ b/sql_generation/model/query/mod.rs @@ -1,16 +1,10 @@ -use std::{collections::HashSet, fmt::Display}; - pub use create::Create; pub use create_index::CreateIndex; pub use delete::Delete; pub use drop::Drop; pub use insert::Insert; pub use select::Select; -use serde::{Deserialize, Serialize}; use turso_parser::ast::fmt::ToSqlContext; -use update::Update; - -use crate::model::query::transaction::{Begin, Commit, Rollback}; pub mod create; pub mod create_index; @@ -22,69 +16,6 @@ pub mod select; pub mod transaction; pub mod update; -// This type represents the potential queries on the database. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum Query { - Create(Create), - Select(Select), - Insert(Insert), - Delete(Delete), - Update(Update), - Drop(Drop), - CreateIndex(CreateIndex), - Begin(Begin), - Commit(Commit), - Rollback(Rollback), -} - -impl Query { - pub fn dependencies(&self) -> HashSet { - match self { - Query::Select(select) => select.dependencies(), - Query::Create(_) => HashSet::new(), - Query::Insert(Insert::Select { table, .. }) - | Query::Insert(Insert::Values { table, .. }) - | Query::Delete(Delete { table, .. }) - | Query::Update(Update { table, .. }) - | Query::Drop(Drop { table, .. }) => HashSet::from_iter([table.clone()]), - Query::CreateIndex(CreateIndex { table_name, .. }) => { - HashSet::from_iter([table_name.clone()]) - } - Query::Begin(_) | Query::Commit(_) | Query::Rollback(_) => HashSet::new(), - } - } - pub fn uses(&self) -> Vec { - match self { - Query::Create(Create { table }) => vec![table.name.clone()], - Query::Select(select) => select.dependencies().into_iter().collect(), - Query::Insert(Insert::Select { table, .. }) - | Query::Insert(Insert::Values { table, .. 
}) - | Query::Delete(Delete { table, .. }) - | Query::Update(Update { table, .. }) - | Query::Drop(Drop { table, .. }) => vec![table.clone()], - Query::CreateIndex(CreateIndex { table_name, .. }) => vec![table_name.clone()], - Query::Begin(..) | Query::Commit(..) | Query::Rollback(..) => vec![], - } - } -} - -impl Display for Query { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Create(create) => write!(f, "{create}"), - Self::Select(select) => write!(f, "{select}"), - Self::Insert(insert) => write!(f, "{insert}"), - Self::Delete(delete) => write!(f, "{delete}"), - Self::Update(update) => write!(f, "{update}"), - Self::Drop(drop) => write!(f, "{drop}"), - Self::CreateIndex(create_index) => write!(f, "{create_index}"), - Self::Begin(begin) => write!(f, "{begin}"), - Self::Commit(commit) => write!(f, "{commit}"), - Self::Rollback(rollback) => write!(f, "{rollback}"), - } - } -} - /// Used to print sql strings that already have all the context it needs pub struct EmptyContext; From ec73b809a9564de74fc63f1178badf5ed1c30967 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Tue, 26 Aug 2025 08:37:35 +0300 Subject: [PATCH 64/73] antithesis-tests: Enable multi-threading --- antithesis-tests/stress/singleton_driver_stress.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/antithesis-tests/stress/singleton_driver_stress.sh b/antithesis-tests/stress/singleton_driver_stress.sh index fcab5ce2a..38c7392b4 100755 --- a/antithesis-tests/stress/singleton_driver_stress.sh +++ b/antithesis-tests/stress/singleton_driver_stress.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -/bin/turso_stress --silent --nr-iterations 10000 +/bin/turso_stress --silent --nr-threads 2 --nr-iterations 10000 From e52f807c7df55d186657efc6fa58a8a75d8b48ed Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Tue, 26 Aug 2025 09:08:48 +0300 Subject: [PATCH 65/73] Fix: return NULL for rowid() when cursor's null flag is on Fixes TPC-H query 13 from returning 
an incorrect result. In this specific case, we were returning non-null `IdxRowid` values for the right-hand side table even when there was no match with the left-hand side table, meaning the join produced matches even in cases where there shouldn't have been any. Closes #2794 --- core/storage/btree.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/storage/btree.rs b/core/storage/btree.rs index b2c36a560..58be7a5e4 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -4320,6 +4320,9 @@ impl BTreeCursor { return Ok(IOResult::Done(None)); } } + if self.get_null_flag() { + return Ok(IOResult::Done(None)); + } if self.has_record.get() { let page = self.stack.top(); let page = page.get(); From 3905f0af46a9a39a1caef9f7cd7c0439712a8b87 Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Tue, 26 Aug 2025 09:21:58 +0300 Subject: [PATCH 66/73] Add regression test for issue 2794 --- testing/join.test | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/testing/join.test b/testing/join.test index db25128ec..853ccc875 100755 --- a/testing/join.test +++ b/testing/join.test @@ -302,3 +302,12 @@ do_execsql_test left-join-backwards-iteration { } {12|Alan| 11|Travis|accessories 10|Daniel|coat} + +# regression test for issue 2794: not nulling out rowid properly when left join does not match +do_execsql_test_on_specific_db {:memory:} min-null-regression-test { + create table t (x integer primary key, y); + create table u (x integer primary key, y); + insert into t values (1,1),(2,2); + insert into u values (1,1),(3,3); + select count(u.x) from t left join u using(y); +} {1} \ No newline at end of file From bf58d179db224943c3e5fc2c9998a97cf277c70e Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Tue, 26 Aug 2025 10:07:18 +0300 Subject: [PATCH 67/73] Improve documentation of page pinning --- core/storage/pager.rs | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 
cf4476dc7..fcee1ec4c 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -124,6 +124,14 @@ pub struct PageInner { pub flags: AtomicUsize, pub contents: Option, pub id: usize, + /// If >0, the page is pinned and not eligible for eviction from the page cache. + /// The reason this is a counter is that multiple nested code paths may signal that + /// a page must not be evicted from the page cache, so even if an inner code path + /// requests unpinning via [Page::unpin], the pin count will still be >0 if the outer + /// code path has not yet requested to unpin the page as well. + /// + /// Note that [DumbLruPageCache::clear] evicts the pages even if pinned, so as long as + /// we clear the page cache on errors, pins will not 'leak'. pub pin_count: AtomicUsize, /// The WAL frame number this page was loaded from (0 if loaded from main DB file) /// This tracks which version of the page we have in memory @@ -238,12 +246,13 @@ impl Page { } } - /// Pin the page to prevent it from being evicted from the page cache. + /// Increment the pin count by 1. A pin count >0 means the page is pinned and not eligible for eviction from the page cache. pub fn pin(&self) { self.get().pin_count.fetch_add(1, Ordering::Relaxed); } - /// Unpin the page to allow it to be evicted from the page cache. + /// Decrement the pin count by 1. If the count reaches 0, the page is no longer + /// pinned and is eligible for eviction from the page cache. pub fn unpin(&self) { let was_pinned = self.try_unpin(); @@ -254,8 +263,8 @@ impl Page { ); } - /// Try to unpin the page if it's pinned, otherwise do nothing. - /// Returns true if the page was originally pinned. + /// Try to decrement the pin count by 1, but do nothing if it was already 0. + /// Returns true if the pin count was decremented. pub fn try_unpin(&self) -> bool { self.get() .pin_count @@ -269,6 +278,7 @@ impl Page { .is_ok() } + /// Returns true if the page is pinned and thus not eligible for eviction from the page cache. 
pub fn is_pinned(&self) -> bool { self.get().pin_count.load(Ordering::Acquire) > 0 } From e65742e5ffc1e3104213cc01b5d7b7c57dbc04eb Mon Sep 17 00:00:00 2001 From: Jussi Saurio Date: Tue, 26 Aug 2025 11:19:29 +0300 Subject: [PATCH 68/73] Fail CI if tursodb output differs from sqlite in tpc-h queries --- perf/tpc-h/run.sh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/perf/tpc-h/run.sh b/perf/tpc-h/run.sh index 7bea14c23..1d572f388 100755 --- a/perf/tpc-h/run.sh +++ b/perf/tpc-h/run.sh @@ -50,6 +50,8 @@ echo "Starting TPC-H query timing comparison..." echo "The script might ask you to enter the password for sudo, in order to clear system caches." clear_caches +exit_code=0 + for query_file in $(ls "$QUERIES_DIR"/*.sql | sort -V); do if [ -f "$query_file" ]; then query_name=$(basename "$query_file") @@ -85,6 +87,7 @@ for query_file in $(ls "$QUERIES_DIR"/*.sql | sort -V); do if [ -n "$output_diff" ]; then echo "Output difference:" echo "$output_diff" + exit_code=1 else echo "No output difference" fi @@ -96,3 +99,8 @@ done echo "-----------------------------------------------------------" echo "TPC-H query timing comparison completed." + +if [ $exit_code -ne 0 ]; then + echo "Error: Output differences found" + exit $exit_code +fi From 26ba09c45ffd57c05a9feb918f44210666551548 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Tue, 26 Aug 2025 14:58:21 +0300 Subject: [PATCH 69/73] Revert "Merge 'Remove double indirection in the Parser' from Pedro Muniz" This reverts commit 71c1b357e4eb53d6eac2c44a4496dddac02d02c8, reversing changes made to 6bc568ff69f6a38f9bd09b0994025d9338a6abd6 because it actually makes things slower. 
--- core/incremental/expr_compiler.rs | 4 +- core/translate/display.rs | 17 +- core/translate/expr.rs | 10 +- core/translate/group_by.rs | 6 +- core/translate/insert.rs | 9 +- .../optimizer/lift_common_subexpressions.rs | 41 +- core/translate/optimizer/mod.rs | 2 +- core/translate/optimizer/order.rs | 4 +- core/translate/order_by.rs | 6 +- core/translate/plan.rs | 8 +- core/translate/planner.rs | 8 +- core/translate/select.rs | 19 +- core/translate/update.rs | 4 +- core/util.rs | 20 +- parser/src/ast.rs | 134 +---- parser/src/parser.rs | 551 +++++++++++------- 16 files changed, 429 insertions(+), 414 deletions(-) diff --git a/core/incremental/expr_compiler.rs b/core/incremental/expr_compiler.rs index dae0687a2..f94d72a2a 100644 --- a/core/incremental/expr_compiler.rs +++ b/core/incremental/expr_compiler.rs @@ -67,7 +67,7 @@ fn transform_expr_for_dbsp(expr: &Expr, input_column_names: &[String]) -> Expr { distinctness: *distinctness, args: args .iter() - .map(|arg| transform_expr_for_dbsp(arg, input_column_names)) + .map(|arg| Box::new(transform_expr_for_dbsp(arg, input_column_names))) .collect(), order_by: order_by.clone(), filter_over: filter_over.clone(), @@ -75,7 +75,7 @@ fn transform_expr_for_dbsp(expr: &Expr, input_column_names: &[String]) -> Expr { Expr::Parenthesized(exprs) => Expr::Parenthesized( exprs .iter() - .map(|e| transform_expr_for_dbsp(e, input_column_names)) + .map(|e| Box::new(transform_expr_for_dbsp(e, input_column_names))) .collect(), ), // For other expression types, keep as is diff --git a/core/translate/display.rs b/core/translate/display.rs index e69c843dd..631e5a295 100644 --- a/core/translate/display.rs +++ b/core/translate/display.rs @@ -368,8 +368,13 @@ impl ToTokens for SelectPlan { context: &C, ) -> Result<(), S::Error> { if !self.values.is_empty() { - ast::OneSelect::Values(self.values.iter().map(|values| values.to_vec()).collect()) - .to_tokens_with_context(s, context)?; + ast::OneSelect::Values( + self.values + .iter() + 
.map(|values| values.iter().map(|v| Box::from(v.clone())).collect()) + .collect(), + ) + .to_tokens_with_context(s, context)?; } else { s.append(TokenType::TK_SELECT, None)?; if self.distinctness.is_distinct() { @@ -443,7 +448,7 @@ impl ToTokens for SelectPlan { s.comma( self.order_by.iter().map(|(expr, order)| ast::SortedColumn { - expr: expr.clone().into_boxed(), + expr: expr.clone(), order: Some(*order), nulls: None, }), @@ -505,7 +510,7 @@ impl ToTokens for DeletePlan { s.comma( self.order_by.iter().map(|(expr, order)| ast::SortedColumn { - expr: expr.clone().into_boxed(), + expr: expr.clone(), order: Some(*order), nulls: None, }), @@ -558,7 +563,7 @@ impl ToTokens for UpdatePlan { ast::Set { col_names: vec![ast::Name::new(col_name)], - expr: set_expr.clone().into_boxed(), + expr: set_expr.clone(), } }), context, @@ -586,7 +591,7 @@ impl ToTokens for UpdatePlan { s.comma( self.order_by.iter().map(|(expr, order)| ast::SortedColumn { - expr: expr.clone().into_boxed(), + expr: expr.clone(), order: Some(*order), nulls: None, }), diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 0e6873561..398b31c3c 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -154,7 +154,7 @@ fn translate_in_list( program: &mut ProgramBuilder, referenced_tables: Option<&TableReferences>, lhs: &ast::Expr, - rhs: &[ast::Expr], + rhs: &[Box], not: bool, condition_metadata: ConditionMetadata, resolver: &Resolver, @@ -1633,7 +1633,9 @@ pub fn translate_expr( ); } - if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = args[1] { + if let ast::Expr::Literal(ast::Literal::Numeric(ref value)) = + args[1].as_ref() + { if let Ok(probability) = value.parse::() { if !(0.0..=1.0).contains(&probability) { crate::bail_parse_error!( @@ -2543,7 +2545,7 @@ fn translate_like_base( /// Returns the target register for the function. 
fn translate_function( program: &mut ProgramBuilder, - args: &[ast::Expr], + args: &[Box], referenced_tables: Option<&TableReferences>, resolver: &Resolver, target_register: usize, @@ -2669,7 +2671,7 @@ pub fn unwrap_parens_owned(expr: ast::Expr) -> Result<(ast::Expr, usize)> { ast::Expr::Parenthesized(mut exprs) => match exprs.len() { 1 => { paren_count += 1; - let (expr, count) = unwrap_parens_owned(exprs.pop().unwrap().clone())?; + let (expr, count) = unwrap_parens_owned(*exprs.pop().unwrap().clone())?; paren_count += count; Ok((expr, paren_count)) } diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 4795f3fee..0a524f348 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -85,7 +85,7 @@ pub fn init_group_by<'a>( group_by: &'a GroupBy, plan: &SelectPlan, result_columns: &'a [ResultSetColumn], - order_by: &'a [(ast::Expr, ast::SortOrder)], + order_by: &'a [(Box, ast::SortOrder)], ) -> Result<()> { collect_non_aggregate_expressions( &mut t_ctx.non_aggregate_expressions, @@ -238,13 +238,13 @@ fn collect_non_aggregate_expressions<'a>( group_by: &'a GroupBy, plan: &SelectPlan, root_result_columns: &'a [ResultSetColumn], - order_by: &'a [(ast::Expr, ast::SortOrder)], + order_by: &'a [(Box, ast::SortOrder)], ) -> Result<()> { let mut result_columns = Vec::new(); for expr in root_result_columns .iter() .map(|col| &col.expr) - .chain(order_by.iter().map(|(e, _)| e)) + .chain(order_by.iter().map(|(e, _)| e.as_ref())) .chain(group_by.having.iter().flatten()) { collect_result_columns(expr, plan, &mut result_columns)?; diff --git a/core/translate/insert.rs b/core/translate/insert.rs index b7560ceb5..3e12204f4 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -100,7 +100,7 @@ pub fn translate_insert( let root_page = btree_table.root_page; - let mut values: Option> = None; + let mut values: Option>> = None; let inserting_multiple_rows = match &mut body { InsertBody::Select(select, _) => match &mut 
select.body.select { // TODO see how to avoid clone @@ -110,10 +110,11 @@ pub fn translate_insert( } let mut param_idx = 1; for expr in values_expr.iter_mut().flat_map(|v| v.iter_mut()) { - match expr { + match expr.as_mut() { Expr::Id(name) => { if name.is_double_quoted() { - *expr = Expr::Literal(ast::Literal::String(name.to_string())); + *expr = + Expr::Literal(ast::Literal::String(format!("{name}"))).into(); } else { // an INSERT INTO ... VALUES (...) cannot reference columns crate::bail_parse_error!("no such column: {name}"); @@ -837,7 +838,7 @@ fn translate_rows_multiple<'short, 'long: 'short>( #[allow(clippy::too_many_arguments)] fn translate_rows_single( program: &mut ProgramBuilder, - value: &[Expr], + value: &[Box], insertion: &Insertion, resolver: &Resolver, ) -> Result<()> { diff --git a/core/translate/optimizer/lift_common_subexpressions.rs b/core/translate/optimizer/lift_common_subexpressions.rs index 6da7c534a..a66a8ab1e 100644 --- a/core/translate/optimizer/lift_common_subexpressions.rs +++ b/core/translate/optimizer/lift_common_subexpressions.rs @@ -104,7 +104,7 @@ pub(crate) fn lift_common_subexpressions_from_binary_or_terms( // If we unwrapped parentheses before, let's add them back. 
let mut top_level_expr = rebuild_and_expr_from_list(conjunct_list_for_or_branch); while num_unwrapped_parens > 0 { - top_level_expr = Expr::Parenthesized(vec![top_level_expr]); + top_level_expr = Expr::Parenthesized(vec![top_level_expr.into()]); num_unwrapped_parens -= 1; } new_or_operands_for_original_term.push(top_level_expr); @@ -246,11 +246,13 @@ mod tests { let or_expr = Expr::Binary( Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), x_expr.clone(), b_expr.clone()], - )])), + ) + .into()])), Operator::Or, Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), y_expr.clone(), b_expr.clone()], - )])), + ) + .into()])), ); let mut where_clause = vec![WhereTerm { @@ -273,9 +275,9 @@ mod tests { assert_eq!( nonconsumed_terms[0].expr, Expr::Binary( - Box::new(ast::Expr::Parenthesized(vec![x_expr.clone()])), + Box::new(ast::Expr::Parenthesized(vec![x_expr.clone().into()])), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![y_expr.clone()])) + Box::new(ast::Expr::Parenthesized(vec![y_expr.clone().into()])) ) ); assert_eq!(nonconsumed_terms[1].expr, a_expr); @@ -340,16 +342,19 @@ mod tests { Box::new(Expr::Binary( Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), x_expr.clone()], - )])), + ) + .into()])), Operator::Or, Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), y_expr.clone()], - )])), + ) + .into()])), )), Operator::Or, Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), z_expr.clone()], - )])), + ) + .into()])), ); let mut where_clause = vec![WhereTerm { @@ -372,12 +377,12 @@ mod tests { nonconsumed_terms[0].expr, Expr::Binary( Box::new(Expr::Binary( - Box::new(ast::Expr::Parenthesized(vec![x_expr])), + Box::new(ast::Expr::Parenthesized(vec![x_expr.into()])), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![y_expr])), + 
Box::new(ast::Expr::Parenthesized(vec![y_expr.into()])), )), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![z_expr])), + Box::new(ast::Expr::Parenthesized(vec![z_expr.into()])), ) ); assert_eq!(nonconsumed_terms[1].expr, a_expr); @@ -414,9 +419,9 @@ mod tests { ); let or_expr = Expr::Binary( - Box::new(ast::Expr::Parenthesized(vec![x_expr])), + Box::new(ast::Expr::Parenthesized(vec![x_expr.into()])), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![y_expr])), + Box::new(ast::Expr::Parenthesized(vec![y_expr.into()])), ); let mut where_clause = vec![WhereTerm { @@ -479,11 +484,13 @@ mod tests { let or_expr = Expr::Binary( Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), x_expr.clone()], - )])), + ) + .into()])), Operator::Or, Box::new(ast::Expr::Parenthesized(vec![rebuild_and_expr_from_list( vec![a_expr.clone(), y_expr.clone()], - )])), + ) + .into()])), ); let mut where_clause = vec![WhereTerm { @@ -503,9 +510,9 @@ mod tests { assert_eq!( nonconsumed_terms[0].expr, Expr::Binary( - Box::new(ast::Expr::Parenthesized(vec![x_expr])), + Box::new(ast::Expr::Parenthesized(vec![x_expr.into()])), Operator::Or, - Box::new(ast::Expr::Parenthesized(vec![y_expr])) + Box::new(ast::Expr::Parenthesized(vec![y_expr.into()])) ) ); assert_eq!( diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs index 005483dec..8502ca005 100644 --- a/core/translate/optimizer/mod.rs +++ b/core/translate/optimizer/mod.rs @@ -186,7 +186,7 @@ fn optimize_table_access( table_references: &mut TableReferences, available_indexes: &HashMap>>, where_clause: &mut [WhereTerm], - order_by: &mut Vec<(ast::Expr, SortOrder)>, + order_by: &mut Vec<(Box, SortOrder)>, group_by: &mut Option, ) -> Result>> { let access_methods_arena = RefCell::new(Vec::new()); diff --git a/core/translate/optimizer/order.rs b/core/translate/optimizer/order.rs index 739fc12b2..b7b3c4edc 100644 --- a/core/translate/optimizer/order.rs +++ 
b/core/translate/optimizer/order.rs @@ -71,7 +71,7 @@ impl OrderTarget { /// TODO: this does not currently handle the case where we definitely cannot eliminate /// the ORDER BY sorter, but we could still eliminate the GROUP BY sorter. pub fn compute_order_target( - order_by: &mut Vec<(ast::Expr, SortOrder)>, + order_by: &mut Vec<(Box, SortOrder)>, group_by_opt: Option<&mut GroupBy>, ) -> Option { match (order_by.is_empty(), group_by_opt) { @@ -79,7 +79,7 @@ pub fn compute_order_target( (true, None) => None, // Only ORDER BY - we would like the joined result rows to be in the order specified by the ORDER BY (false, None) => OrderTarget::maybe_from_iterator( - order_by.iter().map(|(expr, order)| (expr, *order)), + order_by.iter().map(|(expr, order)| (expr.as_ref(), *order)), EliminatesSortBy::Order, ), // Only GROUP BY - we would like the joined result rows to be in the order specified by the GROUP BY diff --git a/core/translate/order_by.rs b/core/translate/order_by.rs index c825592f3..4993e8010 100644 --- a/core/translate/order_by.rs +++ b/core/translate/order_by.rs @@ -36,7 +36,7 @@ pub fn init_order_by( program: &mut ProgramBuilder, t_ctx: &mut TranslateCtx, result_columns: &[ResultSetColumn], - order_by: &[(ast::Expr, SortOrder)], + order_by: &[(Box, SortOrder)], referenced_tables: &TableReferences, ) -> Result<()> { let sort_cursor = program.alloc_cursor_id(CursorType::Sorter); @@ -55,7 +55,7 @@ pub fn init_order_by( */ let collations = order_by .iter() - .map(|(expr, _)| match expr { + .map(|(expr, _)| match expr.as_ref() { ast::Expr::Collate(_, collation_name) => { CollationSeq::new(collation_name.as_str()).map(Some) } @@ -324,7 +324,7 @@ pub struct OrderByRemapping { /// /// If any result columns can be skipped, this returns list of 2-tuples of (SkippedResultColumnIndex: usize, ResultColumnIndexInOrderBySorter: usize) pub fn order_by_deduplicate_result_columns( - order_by: &[(ast::Expr, SortOrder)], + order_by: &[(Box, SortOrder)], result_columns: 
&[ResultSetColumn], ) -> Vec { let mut result_column_remapping: Vec = Vec::new(); diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 9dcbb669b..eba50ce89 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -288,7 +288,7 @@ pub struct SelectPlan { /// group by clause pub group_by: Option, /// order by clause - pub order_by: Vec<(ast::Expr, SortOrder)>, + pub order_by: Vec<(Box, SortOrder)>, /// all the aggregates collected from the result columns, order by, and (TODO) having clauses pub aggregates: Vec, /// limit clause @@ -376,7 +376,7 @@ pub struct DeletePlan { /// where clause split into a vec at 'AND' boundaries. pub where_clause: Vec, /// order by clause - pub order_by: Vec<(ast::Expr, SortOrder)>, + pub order_by: Vec<(Box, SortOrder)>, /// limit clause pub limit: Option, /// offset clause @@ -391,9 +391,9 @@ pub struct DeletePlan { pub struct UpdatePlan { pub table_references: TableReferences, // (colum index, new value) pairs - pub set_clauses: Vec<(usize, ast::Expr)>, + pub set_clauses: Vec<(usize, Box)>, pub where_clause: Vec, - pub order_by: Vec<(ast::Expr, SortOrder)>, + pub order_by: Vec<(Box, SortOrder)>, pub limit: Option, pub offset: Option, // TODO: optional RETURNING clause diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 34287a713..43f012875 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -75,7 +75,7 @@ pub fn resolve_aggregates( } aggs.push(Aggregate { func: f, - args: args.to_vec(), + args: args.iter().map(|arg| *arg.clone()).collect(), original_expr: expr.clone(), distinctness, }); @@ -411,7 +411,7 @@ fn parse_table( vtab_predicates: &mut Vec, qualified_name: &QualifiedName, maybe_alias: Option<&As>, - args: &[Expr], + args: &[Box], connection: &Arc, ) -> Result<()> { let normalized_qualified_name = normalize_ident(qualified_name.name.as_str()); @@ -547,7 +547,7 @@ fn parse_table( } fn transform_args_into_where_terms( - args: &[Expr], + args: &[Box], internal_id: 
TableInternalId, predicates: &mut Vec, table: &Table, @@ -567,7 +567,7 @@ fn transform_args_into_where_terms( column: i, is_rowid_alias: col.is_rowid_alias, }; - let expr = match arg_expr { + let expr = match arg_expr.as_ref() { Expr::Literal(Null) => Expr::IsNull(Box::new(column_expr)), other => Expr::Binary( column_expr.into(), diff --git a/core/translate/select.rs b/core/translate/select.rs index 580312077..d276471aa 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -376,7 +376,8 @@ fn prepare_one_select_plan( // COUNT() case vec![ast::Expr::Literal(ast::Literal::Numeric( "1".to_string(), - ))] + )) + .into()] } (true, _) => crate::bail_parse_error!( "Aggregate function {} requires arguments", @@ -387,7 +388,7 @@ fn prepare_one_select_plan( let agg = Aggregate { func: f, - args: agg_args.to_vec(), + args: agg_args.iter().map(|arg| *arg.clone()).collect(), original_expr: *expr.clone(), distinctness, }; @@ -447,7 +448,10 @@ fn prepare_one_select_plan( } else { let agg = Aggregate { func: AggFunc::External(f.func.clone().into()), - args: args.to_vec(), + args: args + .iter() + .map(|arg| *arg.clone()) + .collect(), original_expr: *expr.clone(), distinctness, }; @@ -567,7 +571,7 @@ fn prepare_one_select_plan( plan.group_by = Some(GroupBy { sort_order: Some((0..group_by.exprs.len()).map(|_| SortOrder::Asc).collect()), - exprs: group_by.exprs.to_vec(), + exprs: group_by.exprs.iter().map(|expr| *expr.clone()).collect(), having: if let Some(having) = group_by.having { let mut predicates = vec![]; break_predicate_at_and_boundaries(&having, &mut predicates); @@ -613,7 +617,7 @@ fn prepare_one_select_plan( )?; resolve_aggregates(schema, &o.expr, &mut plan.aggregates)?; - key.push((*o.expr, o.order.unwrap_or(ast::SortOrder::Asc))); + key.push((o.expr, o.order.unwrap_or(ast::SortOrder::Asc))); } plan.order_by = key; @@ -647,7 +651,10 @@ fn prepare_one_select_plan( contains_constant_false_condition: false, query_destination, distinctness: 
Distinctness::NonDistinct, - values: values.iter().map(|values| values.to_vec()).collect(), + values: values + .iter() + .map(|values| values.iter().map(|value| *value.clone()).collect()) + .collect(), }; Ok(plan) diff --git a/core/translate/update.rs b/core/translate/update.rs index a2e08e9fc..f94ebe118 100644 --- a/core/translate/update.rs +++ b/core/translate/update.rs @@ -175,7 +175,7 @@ pub fn prepare_update_plan( let values = match set.expr.as_ref() { Expr::Parenthesized(vals) => vals.clone(), - expr => vec![expr.clone()], + expr => vec![expr.clone().into()], }; if set.col_names.len() != values.len() { @@ -213,7 +213,7 @@ pub fn prepare_update_plan( let order_by = body .order_by .iter() - .map(|o| (*o.expr.clone(), o.order.unwrap_or(SortOrder::Asc))) + .map(|o| (o.expr.clone(), o.order.unwrap_or(SortOrder::Asc))) .collect(); // Sqlite determines we should create an ephemeral table if we do not have a FROM clause diff --git a/core/util.rs b/core/util.rs index 81bfdedb0..9b7d9e2c8 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1190,8 +1190,8 @@ pub fn parse_pragma_bool(expr: &Expr) -> Result { } /// Extract column name from an expression (e.g., for SELECT clauses) -pub fn extract_column_name_from_expr(expr: &ast::Expr) -> Option { - match expr { +pub fn extract_column_name_from_expr(expr: impl AsRef) -> Option { + match expr.as_ref() { ast::Expr::Id(name) => Some(name.as_str().to_string()), ast::Expr::Qualified(_, name) => Some(name.as_str().to_string()), _ => None, @@ -1435,7 +1435,7 @@ pub mod tests { let func1 = Expr::FunctionCall { name: Name::Ident("SUM".to_string()), distinctness: None, - args: vec![Expr::Id(Name::Ident("x".to_string()))], + args: vec![Expr::Id(Name::Ident("x".to_string())).into()], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1445,7 +1445,7 @@ pub mod tests { let func2 = Expr::FunctionCall { name: Name::Ident("sum".to_string()), distinctness: None, - args: vec![Expr::Id(Name::Ident("x".to_string()))], + args: 
vec![Expr::Id(Name::Ident("x".to_string())).into()], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1457,7 +1457,7 @@ pub mod tests { let func3 = Expr::FunctionCall { name: Name::Ident("SUM".to_string()), distinctness: Some(ast::Distinctness::Distinct), - args: vec![Expr::Id(Name::Ident("x".to_string()))], + args: vec![Expr::Id(Name::Ident("x".to_string())).into()], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1472,7 +1472,7 @@ pub mod tests { let sum = Expr::FunctionCall { name: Name::Ident("SUM".to_string()), distinctness: None, - args: vec![Expr::Id(Name::Ident("x".to_string()))], + args: vec![Expr::Id(Name::Ident("x".to_string())).into()], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1482,7 +1482,7 @@ pub mod tests { let sum_distinct = Expr::FunctionCall { name: Name::Ident("SUM".to_string()), distinctness: Some(ast::Distinctness::Distinct), - args: vec![Expr::Id(Name::Ident("x".to_string()))], + args: vec![Expr::Id(Name::Ident("x".to_string())).into()], order_by: vec![], filter_over: FunctionTail { filter_clause: None, @@ -1513,7 +1513,8 @@ pub mod tests { Box::new(Expr::Literal(Literal::Numeric("683".to_string()))), Add, Box::new(Expr::Literal(Literal::Numeric("799.0".to_string()))), - )]); + ) + .into()]); let expr2 = Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("799".to_string()))), Add, @@ -1527,7 +1528,8 @@ pub mod tests { Box::new(Expr::Literal(Literal::Numeric("6".to_string()))), Add, Box::new(Expr::Literal(Literal::Numeric("7".to_string()))), - )]); + ) + .into()]); let expr8 = Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("6".to_string()))), Add, diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 06a74f9c4..ba1b2bc85 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -287,23 +287,6 @@ pub enum Stmt { }, } -impl Stmt { - pub fn attach(expr: Expr, db_name: Expr, key: Option) -> Stmt { - Stmt::Attach { - expr: Box::new(expr), - db_name: 
Box::new(db_name), - key: key.map(Box::new), - } - } - - pub fn vacuum(name: Option, into: Option) -> Stmt { - Stmt::Vacuum { - name, - into: into.map(Box::new), - } - } -} - #[repr(transparent)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -372,7 +355,7 @@ pub enum Expr { /// operand base: Option>, /// `WHEN` condition `THEN` result - when_then_pairs: Vec<(Expr, Expr)>, + when_then_pairs: Vec<(Box, Box)>, /// `ELSE` result else_expr: Option>, }, @@ -396,7 +379,7 @@ pub enum Expr { /// `DISTINCT` distinctness: Option, /// arguments - args: Vec, + args: Vec>, /// `ORDER BY` order_by: Vec, /// `FILTER` @@ -436,7 +419,7 @@ pub enum Expr { /// `NOT` not: bool, /// values - rhs: Vec, + rhs: Vec>, }, /// `IN` subselect InSelect { @@ -456,7 +439,7 @@ pub enum Expr { /// table name rhs: QualifiedName, /// table function arguments - args: Vec, + args: Vec>, }, /// `IS NULL` IsNull(Box), @@ -480,7 +463,7 @@ pub enum Expr { /// `NOT NULL` or `NOTNULL` NotNull(Box), /// Parenthesized subexpression - Parenthesized(Vec), + Parenthesized(Vec>), /// Qualified name Qualified(Name, Name), /// `RAISE` function call @@ -493,105 +476,6 @@ pub enum Expr { Variable(String), } -impl Expr { - pub fn into_boxed(self) -> Box { - Box::new(self) - } - - pub fn unary(operator: UnaryOperator, expr: Expr) -> Expr { - Expr::Unary(operator, Box::new(expr)) - } - - pub fn binary(lhs: Expr, operator: Operator, rhs: Expr) -> Expr { - Expr::Binary(Box::new(lhs), operator, Box::new(rhs)) - } - - pub fn not_null(expr: Expr) -> Expr { - Expr::NotNull(Box::new(expr)) - } - - pub fn between(lhs: Expr, not: bool, start: Expr, end: Expr) -> Expr { - Expr::Between { - lhs: Box::new(lhs), - not, - start: Box::new(start), - end: Box::new(end), - } - } - - pub fn in_select(lhs: Expr, not: bool, select: Select) -> Expr { - Expr::InSelect { - lhs: Box::new(lhs), - not, - rhs: select, - } - } - - pub fn in_list(lhs: Expr, not: bool, rhs: 
Vec) -> Expr { - Expr::InList { - lhs: Box::new(lhs), - not, - rhs, - } - } - - pub fn in_table(lhs: Expr, not: bool, rhs: QualifiedName, args: Vec) -> Expr { - Expr::InTable { - lhs: Box::new(lhs), - not, - rhs, - args, - } - } - - pub fn like( - lhs: Expr, - not: bool, - operator: LikeOperator, - rhs: Expr, - escape: Option, - ) -> Expr { - Expr::Like { - lhs: Box::new(lhs), - not, - op: operator, - rhs: Box::new(rhs), - escape: escape.map(Box::new), - } - } - - pub fn is_null(expr: Expr) -> Expr { - Expr::IsNull(Box::new(expr)) - } - - pub fn collate(expr: Expr, name: Name) -> Expr { - Expr::Collate(Box::new(expr), name) - } - - pub fn cast(expr: Expr, type_name: Option) -> Expr { - Expr::Cast { - expr: Box::new(expr), - type_name, - } - } - - pub fn case( - base: Option, - when_then_pairs: Vec<(Expr, Expr)>, - else_expr: Option, - ) -> Expr { - Expr::Case { - base: base.map(Box::new), - when_then_pairs, - else_expr: else_expr.map(Box::new), - } - } - - pub fn raise(resolve_type: ResolveType, expr: Option) -> Expr { - Expr::Raise(resolve_type, expr.map(Box::new)) - } -} - /// SQL literal #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -798,7 +682,7 @@ pub enum OneSelect { window_clause: Vec, }, /// `VALUES` - Values(Vec>), + Values(Vec>>), } /// `SELECT` ... 
`FROM` clause @@ -866,7 +750,7 @@ pub enum SelectTable { /// table Table(QualifiedName, Option, Option), /// table function call - TableCall(QualifiedName, Vec, Option), + TableCall(QualifiedName, Vec>, Option), /// `SELECT` subquery Select(Select, Option), /// subquery @@ -920,7 +804,7 @@ pub enum JoinConstraint { #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct GroupBy { /// expressions - pub exprs: Vec, + pub exprs: Vec>, /// `HAVING` pub having: Option>, // HAVING clause on a non-aggregate query } @@ -1667,7 +1551,7 @@ pub struct Window { /// base window name pub base: Option, /// `PARTITION BY` - pub partition_by: Vec, + pub partition_by: Vec>, /// `ORDER BY` pub order_by: Vec, /// frame spec diff --git a/parser/src/parser.rs b/parser/src/parser.rs index c21a8fd7b..124058fc9 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -953,11 +953,11 @@ impl<'a> Parser<'a> { } } - fn parse_signed(&mut self) -> Result { + fn parse_signed(&mut self) -> Result> { peek_expect!(self, TK_FLOAT, TK_INTEGER, TK_PLUS, TK_MINUS); let expr = self.parse_expr_operand()?; - match &expr { + match expr.as_ref() { Expr::Unary(_, inner) => match inner.as_ref() { Expr::Literal(Literal::Numeric(_)) => Ok(expr), _ => Err(Error::Custom( @@ -998,14 +998,11 @@ impl<'a> Parser<'a> { let first_size = self.parse_signed()?; let tok = eat_expect!(self, TK_RP, TK_COMMA); match tok.token_type.unwrap() { - TK_RP => Some(TypeSize::MaxSize(Box::new(first_size))), + TK_RP => Some(TypeSize::MaxSize(first_size)), TK_COMMA => { let second_size = self.parse_signed()?; eat_expect!(self, TK_RP); - Some(TypeSize::TypeSize( - Box::new(first_size), - Box::new(second_size), - )) + Some(TypeSize::TypeSize(first_size, second_size)) } _ => unreachable!(), } @@ -1101,7 +1098,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_WHERE); let expr = self.parse_expr(0)?; eat_expect!(self, TK_RP); - Ok(Some(Box::new(expr))) + Ok(Some(expr)) } fn parse_frame_opt(&mut self) -> 
Result> { @@ -1144,7 +1141,7 @@ impl<'a> Parser<'a> { FrameBound::CurrentRow } _ => { - let expr = Box::new(self.parse_expr(0)?); + let expr = self.parse_expr(0)?; let tok = eat_expect!(self, TK_PRECEDING, TK_FOLLOWING); match tok.token_type.unwrap() { TK_PRECEDING => FrameBound::Preceding(expr), @@ -1169,7 +1166,7 @@ impl<'a> Parser<'a> { FrameBound::CurrentRow } _ => { - let expr = Box::new(self.parse_expr(0)?); + let expr = self.parse_expr(0)?; let tok = eat_expect!(self, TK_PRECEDING, TK_FOLLOWING); match tok.token_type.unwrap() { TK_PRECEDING => FrameBound::Preceding(expr), @@ -1288,7 +1285,7 @@ impl<'a> Parser<'a> { } } - fn parse_expr_operand(&mut self) -> Result { + fn parse_expr_operand(&mut self) -> Result> { let tok = peek_expect!( self, TK_LP, @@ -1319,34 +1316,40 @@ impl<'a> Parser<'a> { TK_WITH | TK_SELECT | TK_VALUES => { let select = self.parse_select()?; eat_expect!(self, TK_RP); - Ok(Expr::Subquery(select)) + Ok(Box::new(Expr::Subquery(select))) } _ => { let exprs = self.parse_nexpr_list()?; eat_expect!(self, TK_RP); - Ok(Expr::Parenthesized(exprs)) + Ok(Box::new(Expr::Parenthesized(exprs))) } } } TK_NULL => { eat_assert!(self, TK_NULL); - Ok(Expr::Literal(Literal::Null)) + Ok(Box::new(Expr::Literal(Literal::Null))) } TK_BLOB => { let tok = eat_assert!(self, TK_BLOB); - Ok(Expr::Literal(Literal::Blob(from_bytes(tok.value)))) + Ok(Box::new(Expr::Literal(Literal::Blob(from_bytes( + tok.value, + ))))) } TK_FLOAT => { let tok = eat_assert!(self, TK_FLOAT); - Ok(Expr::Literal(Literal::Numeric(from_bytes(tok.value)))) + Ok(Box::new(Expr::Literal(Literal::Numeric(from_bytes( + tok.value, + ))))) } TK_INTEGER => { let tok = eat_assert!(self, TK_INTEGER); - Ok(Expr::Literal(Literal::Numeric(from_bytes(tok.value)))) + Ok(Box::new(Expr::Literal(Literal::Numeric(from_bytes( + tok.value, + ))))) } TK_VARIABLE => { let tok = eat_assert!(self, TK_VARIABLE); - Ok(Expr::Variable(from_bytes(tok.value))) + Ok(Box::new(Expr::Variable(from_bytes(tok.value)))) } 
TK_CAST => { eat_assert!(self, TK_CAST); @@ -1355,16 +1358,19 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_AS); let typ = self.parse_type()?; eat_expect!(self, TK_RP); - Ok(Expr::cast(expr, typ)) + Ok(Box::new(Expr::Cast { + expr, + type_name: typ, + })) } TK_CTIME_KW => { let tok = eat_assert!(self, TK_CTIME_KW); if b"CURRENT_DATE".eq_ignore_ascii_case(tok.value) { - Ok(Expr::Literal(Literal::CurrentDate)) + Ok(Box::new(Expr::Literal(Literal::CurrentDate))) } else if b"CURRENT_TIME".eq_ignore_ascii_case(tok.value) { - Ok(Expr::Literal(Literal::CurrentTime)) + Ok(Box::new(Expr::Literal(Literal::CurrentTime))) } else if b"CURRENT_TIMESTAMP".eq_ignore_ascii_case(tok.value) { - Ok(Expr::Literal(Literal::CurrentTimestamp)) + Ok(Box::new(Expr::Literal(Literal::CurrentTimestamp))) } else { unreachable!() } @@ -1372,29 +1378,29 @@ impl<'a> Parser<'a> { TK_NOT => { eat_assert!(self, TK_NOT); let expr = self.parse_expr(2)?; // NOT precedence is 2 - Ok(Expr::unary(UnaryOperator::Not, expr)) + Ok(Box::new(Expr::Unary(UnaryOperator::Not, expr))) } TK_BITNOT => { eat_assert!(self, TK_BITNOT); let expr = self.parse_expr(11)?; // BITNOT precedence is 11 - Ok(Expr::unary(UnaryOperator::BitwiseNot, expr)) + Ok(Box::new(Expr::Unary(UnaryOperator::BitwiseNot, expr))) } TK_PLUS => { eat_assert!(self, TK_PLUS); let expr = self.parse_expr(11)?; // PLUS precedence is 11 - Ok(Expr::unary(UnaryOperator::Positive, expr)) + Ok(Box::new(Expr::Unary(UnaryOperator::Positive, expr))) } TK_MINUS => { eat_assert!(self, TK_MINUS); let expr = self.parse_expr(11)?; // MINUS precedence is 11 - Ok(Expr::unary(UnaryOperator::Negative, expr)) + Ok(Box::new(Expr::Unary(UnaryOperator::Negative, expr))) } TK_EXISTS => { eat_assert!(self, TK_EXISTS); eat_expect!(self, TK_LP); let select = self.parse_select()?; eat_expect!(self, TK_RP); - Ok(Expr::Exists(select)) + Ok(Box::new(Expr::Exists(select))) } TK_CASE => { eat_assert!(self, TK_CASE); @@ -1433,7 +1439,11 @@ impl<'a> Parser<'a> { }; eat_expect!(self, 
TK_END); - Ok(Expr::case(base, when_then_pairs, else_expr)) + Ok(Box::new(Expr::Case { + base, + when_then_pairs, + else_expr, + })) } TK_RAISE => { eat_assert!(self, TK_RAISE); @@ -1455,7 +1465,7 @@ impl<'a> Parser<'a> { }; eat_expect!(self, TK_RP); - Ok(Expr::raise(resolve, expr)) + Ok(Box::new(Expr::Raise(resolve, expr))) } _ => { let can_be_lit_str = tok.token_type == Some(TK_STRING); @@ -1480,10 +1490,10 @@ impl<'a> Parser<'a> { TK_STAR => { eat_assert!(self, TK_STAR); eat_expect!(self, TK_RP); - return Ok(Expr::FunctionCallStar { + return Ok(Box::new(Expr::FunctionCallStar { name, filter_over: self.parse_filter_over()?, - }); + })); } _ => { let distinct = self.parse_distinct()?; @@ -1491,13 +1501,13 @@ impl<'a> Parser<'a> { let order_by = self.parse_order_by()?; eat_expect!(self, TK_RP); let filter_over = self.parse_filter_over()?; - return Ok(Expr::FunctionCall { + return Ok(Box::new(Expr::FunctionCall { name, distinctness: distinct, args: exprs, order_by, filter_over, - }); + })); } } } else { @@ -1521,24 +1531,28 @@ impl<'a> Parser<'a> { if let Some(second_name) = second_name { if let Some(third_name) = third_name { - Ok(Expr::DoublyQualified(name, second_name, third_name)) + Ok(Box::new(Expr::DoublyQualified( + name, + second_name, + third_name, + ))) } else { - Ok(Expr::Qualified(name, second_name)) + Ok(Box::new(Expr::Qualified(name, second_name))) } } else if can_be_lit_str { - Ok(Expr::Literal(match name { + Ok(Box::new(Expr::Literal(match name { Name::Quoted(s) => Literal::String(s), Name::Ident(s) => Literal::String(s), - })) + }))) } else { - Ok(Expr::Id(name)) + Ok(Box::new(Expr::Id(name))) } } } } #[allow(clippy::vec_box)] - fn parse_expr_list(&mut self) -> Result> { + fn parse_expr_list(&mut self) -> Result>> { let mut exprs = vec![]; while let Some(tok) = self.peek()? 
{ match tok.token_type.unwrap().fallback_id_if_ok() { @@ -1560,7 +1574,7 @@ impl<'a> Parser<'a> { Ok(exprs) } - fn parse_expr(&mut self, precedence: u8) -> Result { + fn parse_expr(&mut self, precedence: u8) -> Result> { let mut result = self.parse_expr_operand()?; loop { @@ -1580,28 +1594,44 @@ impl<'a> Parser<'a> { not = true; } - let expr = match tok.token_type.unwrap() { + result = match tok.token_type.unwrap() { TK_NULL => { // special case `NOT NULL` debug_assert!(not); // FIXME: not always true because of current_token_precedence eat_assert!(self, TK_NULL); - Expr::not_null(result) + Box::new(Expr::NotNull(result)) } TK_OR => { eat_assert!(self, TK_OR); - Expr::binary(result, Operator::Or, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::Or, + self.parse_expr(pre + 1)?, + )) } TK_AND => { eat_assert!(self, TK_AND); - Expr::binary(result, Operator::And, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::And, + self.parse_expr(pre + 1)?, + )) } TK_EQ => { eat_assert!(self, TK_EQ); - Expr::binary(result, Operator::Equals, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::Equals, + self.parse_expr(pre + 1)?, + )) } TK_NE => { eat_assert!(self, TK_NE); - Expr::binary(result, Operator::NotEquals, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::NotEquals, + self.parse_expr(pre + 1)?, + )) } TK_IS => { eat_assert!(self, TK_IS); @@ -1633,14 +1663,19 @@ impl<'a> Parser<'a> { } }; - Expr::binary(result, op, self.parse_expr(pre + 1)?) 
+ Box::new(Expr::Binary(result, op, self.parse_expr(pre + 1)?)) } TK_BETWEEN => { eat_assert!(self, TK_BETWEEN); let start = self.parse_expr(pre)?; eat_expect!(self, TK_AND); let end = self.parse_expr(pre)?; - Expr::between(result, not, start, end) + Box::new(Expr::Between { + lhs: result, + not, + start, + end, + }) } TK_IN => { eat_assert!(self, TK_IN); @@ -1653,12 +1688,20 @@ impl<'a> Parser<'a> { TK_SELECT | TK_WITH | TK_VALUES => { let select = self.parse_select()?; eat_expect!(self, TK_RP); - Expr::in_select(result, not, select) + Box::new(Expr::InSelect { + lhs: result, + not, + rhs: select, + }) } _ => { let exprs = self.parse_expr_list()?; eat_expect!(self, TK_RP); - Expr::in_list(result, not, exprs) + Box::new(Expr::InList { + lhs: result, + not, + rhs: exprs, + }) } } } @@ -1672,7 +1715,13 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_RP); } } - Expr::in_table(result, not, name, exprs) + + Box::new(Expr::InTable { + lhs: result, + not, + rhs: name, + args: exprs, + }) } } } @@ -1706,72 +1755,134 @@ impl<'a> Parser<'a> { None }; - Expr::like(result, not, op, expr, escape) + Box::new(Expr::Like { + lhs: result, + not, + op, + rhs: expr, + escape, + }) } TK_ISNULL => { eat_assert!(self, TK_ISNULL); - Expr::is_null(result) + Box::new(Expr::IsNull(result)) } TK_NOTNULL => { eat_assert!(self, TK_NOTNULL); - Expr::not_null(result) + Box::new(Expr::NotNull(result)) } TK_LT => { eat_assert!(self, TK_LT); - Expr::binary(result, Operator::Less, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::Less, + self.parse_expr(pre + 1)?, + )) } TK_GT => { eat_assert!(self, TK_GT); - Expr::binary(result, Operator::Greater, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::Greater, + self.parse_expr(pre + 1)?, + )) } TK_LE => { eat_assert!(self, TK_LE); - Expr::binary(result, Operator::LessEquals, self.parse_expr(pre + 1)?) 
+ Box::new(Expr::Binary( + result, + Operator::LessEquals, + self.parse_expr(pre + 1)?, + )) } TK_GE => { eat_assert!(self, TK_GE); - Expr::binary(result, Operator::GreaterEquals, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::GreaterEquals, + self.parse_expr(pre + 1)?, + )) } TK_ESCAPE => unreachable!(), TK_BITAND => { eat_assert!(self, TK_BITAND); - Expr::binary(result, Operator::BitwiseAnd, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::BitwiseAnd, + self.parse_expr(pre + 1)?, + )) } TK_BITOR => { eat_assert!(self, TK_BITOR); - Expr::binary(result, Operator::BitwiseOr, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::BitwiseOr, + self.parse_expr(pre + 1)?, + )) } TK_LSHIFT => { eat_assert!(self, TK_LSHIFT); - Expr::binary(result, Operator::LeftShift, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::LeftShift, + self.parse_expr(pre + 1)?, + )) } TK_RSHIFT => { eat_assert!(self, TK_RSHIFT); - Expr::binary(result, Operator::RightShift, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::RightShift, + self.parse_expr(pre + 1)?, + )) } TK_PLUS => { eat_assert!(self, TK_PLUS); - Expr::binary(result, Operator::Add, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::Add, + self.parse_expr(pre + 1)?, + )) } TK_MINUS => { eat_assert!(self, TK_MINUS); - Expr::binary(result, Operator::Subtract, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::Subtract, + self.parse_expr(pre + 1)?, + )) } TK_STAR => { eat_assert!(self, TK_STAR); - Expr::binary(result, Operator::Multiply, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::Multiply, + self.parse_expr(pre + 1)?, + )) } TK_SLASH => { eat_assert!(self, TK_SLASH); - Expr::binary(result, Operator::Divide, self.parse_expr(pre + 1)?) 
+ Box::new(Expr::Binary( + result, + Operator::Divide, + self.parse_expr(pre + 1)?, + )) } TK_REM => { eat_assert!(self, TK_REM); - Expr::binary(result, Operator::Modulus, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::Modulus, + self.parse_expr(pre + 1)?, + )) } TK_CONCAT => { eat_assert!(self, TK_CONCAT); - Expr::binary(result, Operator::Concat, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary( + result, + Operator::Concat, + self.parse_expr(pre + 1)?, + )) } TK_PTR => { let tok = eat_assert!(self, TK_PTR); @@ -1781,13 +1892,11 @@ impl<'a> Parser<'a> { Operator::ArrowRightShift }; - Expr::binary(result, op, self.parse_expr(pre + 1)?) + Box::new(Expr::Binary(result, op, self.parse_expr(pre + 1)?)) } - TK_COLLATE => Expr::collate(result, self.parse_collate()?.unwrap()), + TK_COLLATE => Box::new(Expr::Collate(result, self.parse_collate()?.unwrap())), _ => unreachable!(), - }; - - result = expr; + } } Ok(result) @@ -2005,10 +2114,7 @@ impl<'a> Parser<'a> { _ => None, }; - Ok(Some(GroupBy { - exprs, - having: having.map(Box::new), - })) + Ok(Some(GroupBy { exprs, having })) } fn parse_where(&mut self) -> Result>> { @@ -2018,7 +2124,7 @@ impl<'a> Parser<'a> { TK_WHERE => { eat_assert!(self, TK_WHERE); let expr = self.parse_expr(0)?; - Ok(Some(expr.into_boxed())) + Ok(Some(expr)) } _ => Ok(None), }, @@ -2080,7 +2186,7 @@ impl<'a> Parser<'a> { TK_ON => { eat_assert!(self, TK_ON); let expr = self.parse_expr(0)?; - Ok(Some(JoinConstraint::On(expr.into_boxed()))) + Ok(Some(JoinConstraint::On(expr))) } TK_USING => { eat_assert!(self, TK_USING); @@ -2308,7 +2414,7 @@ impl<'a> Parser<'a> { let expr = self.parse_expr(0)?; let alias = self.parse_as()?; - Ok(ResultColumn::Expr(expr.into_boxed(), alias)) + Ok(ResultColumn::Expr(expr, alias)) } } } @@ -2330,7 +2436,7 @@ impl<'a> Parser<'a> { } #[allow(clippy::vec_box)] - fn parse_nexpr_list(&mut self) -> Result> { + fn parse_nexpr_list(&mut self) -> Result>> { let mut result = vec![self.parse_expr(0)?]; 
while let Some(tok) = self.peek()? { if tok.token_type == Some(TK_COMMA) { @@ -2428,7 +2534,7 @@ impl<'a> Parser<'a> { } fn parse_sorted_column(&mut self) -> Result { - let expr = self.parse_expr(0)?.into_boxed(); + let expr = self.parse_expr(0)?; let sort_order = self.parse_sort_order()?; let nulls = match self.peek()? { @@ -2492,7 +2598,7 @@ impl<'a> Parser<'a> { return Ok(None); } - let limit = self.parse_expr(0)?.into_boxed(); + let limit = self.parse_expr(0)?; let offset = match self.peek()? { Some(tok) => match tok.token_type.unwrap() { TK_OFFSET | TK_COMMA => { @@ -2502,8 +2608,7 @@ impl<'a> Parser<'a> { _ => None, }, _ => None, - } - .map(Box::new); + }; Ok(Some(Limit { expr: limit, @@ -2560,7 +2665,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_LP); let expr = self.parse_expr(0)?; eat_expect!(self, TK_RP); - Ok(TableConstraint::Check(expr.into_boxed())) + Ok(TableConstraint::Check(expr)) } fn parse_foreign_key_table_constraint(&mut self) -> Result { @@ -2847,7 +2952,7 @@ impl<'a> Parser<'a> { _ => None, }; - Ok(Stmt::attach(expr, db_name, key)) + Ok(Stmt::Attach { expr, db_name, key }) } fn parse_detach(&mut self) -> Result { @@ -2857,20 +2962,23 @@ impl<'a> Parser<'a> { } Ok(Stmt::Detach { - name: self.parse_expr(0)?.into_boxed(), + name: self.parse_expr(0)?, }) } fn parse_pragma_value(&mut self) -> Result { - let expr = match self.peek_no_eof()?.token_type.unwrap().fallback_id_if_ok() { + match self.peek_no_eof()?.token_type.unwrap().fallback_id_if_ok() { TK_ON | TK_DELETE | TK_DEFAULT => { let tok = eat_assert!(self, TK_ON, TK_DELETE, TK_DEFAULT); - Expr::Literal(Literal::Keyword(from_bytes(tok.value))) + Ok(Box::new(Expr::Literal(Literal::Keyword(from_bytes( + tok.value, + ))))) } - TK_ID | TK_STRING | TK_INDEXED | TK_JOIN_KW => Expr::Name(self.parse_nm()?), - _ => self.parse_signed()?, - }; - Ok(Box::new(expr)) + TK_ID | TK_STRING | TK_INDEXED | TK_JOIN_KW => { + Ok(Box::new(Expr::Name(self.parse_nm()?))) + } + _ => self.parse_signed(), + } } fn 
parse_pragma(&mut self) -> Result { @@ -2919,7 +3027,7 @@ impl<'a> Parser<'a> { _ => None, }; - Ok(Stmt::vacuum(name, into)) + Ok(Stmt::Vacuum { name, into }) } fn parse_term(&mut self) -> Result> { @@ -2933,7 +3041,7 @@ impl<'a> Parser<'a> { TK_CTIME_KW, ); - self.parse_expr_operand().map(Box::new) + self.parse_expr_operand() } fn parse_default_column_constraint(&mut self) -> Result { @@ -3064,7 +3172,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_LP); let expr = self.parse_expr(0)?; eat_expect!(self, TK_RP); - Ok(ColumnConstraint::Check(expr.into_boxed())) + Ok(ColumnConstraint::Check(expr)) } fn parse_ref_act(&mut self) -> Result { @@ -3185,7 +3293,7 @@ impl<'a> Parser<'a> { } eat_expect!(self, TK_LP); - let expr = self.parse_expr(0)?.into_boxed(); + let expr = self.parse_expr(0)?; eat_expect!(self, TK_RP); let typ = match self.peek()? { @@ -3421,7 +3529,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_EQ); Ok(Set { col_names: names, - expr: self.parse_expr(0)?.into_boxed(), + expr: self.parse_expr(0)?, }) } _ => { @@ -3429,7 +3537,7 @@ impl<'a> Parser<'a> { eat_expect!(self, TK_EQ); Ok(Set { col_names: vec![name], - expr: self.parse_expr(0)?.into_boxed(), + expr: self.parse_expr(0)?, }) } } @@ -3658,8 +3766,7 @@ impl<'a> Parser<'a> { Some(self.parse_expr(0)?) 
} _ => None, - } - .map(Box::new); + }; eat_expect!(self, TK_BEGIN); @@ -4146,9 +4253,9 @@ mod tests { select: OneSelect::Select { distinctness: None, columns: vec![ResultColumn::Expr( - Expr::Parenthesized(vec![Expr::Literal( + Box::new(Expr::Parenthesized(vec![Box::new(Expr::Literal( Literal::Numeric("1".to_owned()), - )]).into_boxed(), + ))])), None, )], from: None, @@ -4625,8 +4732,8 @@ mod tests { Box::new(Expr::Case { base: None, when_then_pairs: vec![( - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), )], else_expr: Some(Box::new(Expr::Literal(Literal::Numeric( "3".to_owned(), @@ -4658,8 +4765,8 @@ mod tests { "4".to_owned(), )))), when_then_pairs: vec![( - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), )], else_expr: Some(Box::new(Expr::Literal(Literal::Numeric( "3".to_owned(), @@ -4691,8 +4798,8 @@ mod tests { "4".to_owned(), )))), when_then_pairs: vec![( - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), )], else_expr: None, }), @@ -5019,8 +5126,8 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { @@ -5057,17 +5164,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: 
Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: None, - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: None, })), @@ -5098,17 +5205,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: None, })), @@ -5139,17 +5246,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![ SortedColumn { expr: Box::new(Expr::Id(Name::Ident("test".to_owned()))), @@ -5186,17 
+5293,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Rows, @@ -5232,17 +5339,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Range, @@ -5278,17 +5385,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: 
vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5324,17 +5431,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5370,17 +5477,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5416,17 +5523,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { 
filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5462,17 +5569,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5508,17 +5615,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5556,17 +5663,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + 
Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5604,17 +5711,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5650,17 +5757,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5696,17 +5803,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), 
distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -5742,17 +5849,17 @@ mod tests { name: Name::Ident("func_name".to_owned()), distinctness: Some(Distinctness::Distinct), args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], order_by: vec![], filter_over: FunctionTail { filter_clause: None, over_clause: Some(Over::Window(Window { base: Some(Name::Ident("test".to_owned())), - partition_by: vec![Expr::Id(Name::Ident( + partition_by: vec![Box::new(Expr::Id(Name::Ident( "product".to_owned(), - ))], + )))], order_by: vec![], frame_clause: Some(FrameClause{ mode: FrameMode::Groups, @@ -6216,9 +6323,9 @@ mod tests { lhs: Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), not: false, rhs: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), - Expr::Literal(Literal::Numeric("3".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("3".to_owned()))), ], }), Operator::And, @@ -6255,9 +6362,9 @@ mod tests { alias: None, }, args: vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - 
Expr::Literal(Literal::Numeric("2".to_owned())), - Expr::Literal(Literal::Numeric("3".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("3".to_owned()))), ], }), Operator::And, @@ -6951,16 +7058,16 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], vec![ - Expr::Literal(Literal::Numeric("3".to_owned())), - Expr::Literal(Literal::Numeric("4".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("3".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("4".to_owned()))), ], vec![ - Expr::Literal(Literal::Numeric("5".to_owned())), - Expr::Literal(Literal::Numeric("6".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("5".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("6".to_owned()))), ], ]), compounds: vec![], @@ -7640,8 +7747,8 @@ mod tests { select: Box::new(SelectTable::TableCall( QualifiedName { db_name: None, name: Name::Ident("foo".to_owned()), alias: None }, vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], None, )), @@ -8373,8 +8480,8 @@ mod tests { table: Box::new(SelectTable::TableCall( QualifiedName { db_name: None, name: Name::Ident("bar".to_owned()), alias: None }, vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], None, )), @@ -8418,12 +8525,12 @@ mod tests { body: SelectBody { select: 
OneSelect::Values(vec![ vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())) + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))) ], vec![ - Expr::Literal(Literal::Numeric("3".to_owned())), - Expr::Literal(Literal::Numeric("4".to_owned())) + Box::new(Expr::Literal(Literal::Numeric("3".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("4".to_owned()))) ], ]), compounds: vec![], @@ -8547,11 +8654,11 @@ mod tests { where_clause: None, group_by: Some(GroupBy { exprs: vec![ - Expr::Binary( + Box::new(Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), Operator::Equals, Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - ), + )), ], having: None, }), @@ -8585,11 +8692,11 @@ mod tests { where_clause: None, group_by: Some(GroupBy { exprs: vec![ - Expr::Binary( + Box::new(Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), Operator::Equals, Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - ), + )), ], having: Some(Box::new(Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), @@ -8627,11 +8734,11 @@ mod tests { where_clause: None, group_by: Some(GroupBy { exprs: vec![ - Expr::Binary( + Box::new(Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), Operator::Equals, Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), - ), + )), ], having: Some(Box::new(Expr::Binary( Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), @@ -8707,7 +8814,7 @@ mod tests { window: Window { base: None, partition_by: vec![ - Expr::Id(Name::Ident("product".to_owned())), + Box::new(Expr::Id(Name::Ident("product".to_owned()))), ], order_by: vec![], frame_clause: None, @@ -8745,7 +8852,7 @@ mod tests { window: Window { base: None, partition_by: vec![ - Expr::Id(Name::Ident("product".to_owned())), + Box::new(Expr::Id(Name::Ident("product".to_owned()))), ], order_by: 
vec![], frame_clause: None, @@ -8756,7 +8863,7 @@ mod tests { window: Window { base: None, partition_by: vec![ - Expr::Id(Name::Ident("product_2".to_owned())), + Box::new(Expr::Id(Name::Ident("product_2".to_owned()))), ], order_by: vec![], frame_clause: None, @@ -9010,7 +9117,7 @@ mod tests { name: None, constraint: ColumnConstraint::Default( Box::new(Expr::Parenthesized(vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), ])) ), }, @@ -10242,8 +10349,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], ]), compounds: vec![], @@ -10289,8 +10396,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], ]), compounds: vec![], @@ -10333,8 +10440,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], ]), compounds: vec![], @@ -10386,8 +10493,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], ]), compounds: vec![], @@ -10457,8 +10564,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - 
Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], ]), compounds: vec![], @@ -10525,8 +10632,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], ]), compounds: vec![], @@ -11033,8 +11140,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], ]), compounds: vec![], @@ -11061,8 +11168,8 @@ mod tests { body: SelectBody { select: OneSelect::Values(vec![ vec![ - Expr::Literal(Literal::Numeric("1".to_owned())), - Expr::Literal(Literal::Numeric("2".to_owned())), + Box::new(Expr::Literal(Literal::Numeric("1".to_owned()))), + Box::new(Expr::Literal(Literal::Numeric("2".to_owned()))), ], ]), compounds: vec![], From aa025c979838da5ddf8816f9c6704ead4030707a Mon Sep 17 00:00:00 2001 From: pedrocarlo Date: Tue, 26 Aug 2025 10:08:59 -0300 Subject: [PATCH 70/73] fix missing functions after revert --- parser/src/ast.rs | 70 ++++++++++++++++++++ simulator-docker-runner/Dockerfile.simulator | 2 + sql_generation/model/query/predicate.rs | 2 +- 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/parser/src/ast.rs b/parser/src/ast.rs index ba1b2bc85..a452039bb 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -476,6 +476,76 @@ pub enum Expr { Variable(String), } +impl Expr { + pub fn into_boxed(self) -> Box { + Box::new(self) + } + + pub fn unary(operator: UnaryOperator, expr: Expr) -> Expr { + 
Expr::Unary(operator, Box::new(expr)) + } + + pub fn binary(lhs: Expr, operator: Operator, rhs: Expr) -> Expr { + Expr::Binary(Box::new(lhs), operator, Box::new(rhs)) + } + + pub fn not_null(expr: Expr) -> Expr { + Expr::NotNull(Box::new(expr)) + } + + pub fn between(lhs: Expr, not: bool, start: Expr, end: Expr) -> Expr { + Expr::Between { + lhs: Box::new(lhs), + not, + start: Box::new(start), + end: Box::new(end), + } + } + + pub fn in_select(lhs: Expr, not: bool, select: Select) -> Expr { + Expr::InSelect { + lhs: Box::new(lhs), + not, + rhs: select, + } + } + + pub fn like( + lhs: Expr, + not: bool, + operator: LikeOperator, + rhs: Expr, + escape: Option, + ) -> Expr { + Expr::Like { + lhs: Box::new(lhs), + not, + op: operator, + rhs: Box::new(rhs), + escape: escape.map(Box::new), + } + } + + pub fn is_null(expr: Expr) -> Expr { + Expr::IsNull(Box::new(expr)) + } + + pub fn collate(expr: Expr, name: Name) -> Expr { + Expr::Collate(Box::new(expr), name) + } + + pub fn cast(expr: Expr, type_name: Option) -> Expr { + Expr::Cast { + expr: Box::new(expr), + type_name, + } + } + + pub fn raise(resolve_type: ResolveType, expr: Option) -> Expr { + Expr::Raise(resolve_type, expr.map(Box::new)) + } +} + /// SQL literal #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] diff --git a/simulator-docker-runner/Dockerfile.simulator b/simulator-docker-runner/Dockerfile.simulator index f70a1e7ad..74f579ed6 100644 --- a/simulator-docker-runner/Dockerfile.simulator +++ b/simulator-docker-runner/Dockerfile.simulator @@ -26,6 +26,7 @@ COPY stress ./stress/ COPY tests ./tests/ COPY packages ./packages/ COPY testing/sqlite_test_ext ./testing/sqlite_test_ext +COPY sql_generation ./sql_generation/ RUN cargo chef prepare --bin limbo_sim --recipe-path recipe.json # @@ -43,6 +44,7 @@ COPY --from=planner /app/macros ./macros/ COPY --from=planner /app/parser ./parser/ COPY --from=planner /app/simulator ./simulator/ COPY 
--from=planner /app/packages ./packages/ +COPY --from=planner /app/sql_generation ./sql_generation/ RUN cargo build --bin limbo_sim --release diff --git a/sql_generation/model/query/predicate.rs b/sql_generation/model/query/predicate.rs index 30b671d72..29f0b966c 100644 --- a/sql_generation/model/query/predicate.rs +++ b/sql_generation/model/query/predicate.rs @@ -71,7 +71,7 @@ impl Predicate { } pub fn parens(self) -> Self { - let expr = ast::Expr::Parenthesized(vec![self.0]); + let expr = ast::Expr::Parenthesized(vec![Box::new(self.0)]); Self(expr) } From caa00e31f8f60d12d690f7d8f15d54ae637e9808 Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Tue, 26 Aug 2025 20:00:13 +0530 Subject: [PATCH 71/73] Use `Cell` instead of `RefCell` because its nice --- core/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index cfe794a2e..553076375 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -464,7 +464,7 @@ impl Database { metrics: RefCell::new(ConnectionMetrics::new()), is_nested_stmt: Cell::new(false), encryption_key: RefCell::new(None), - encryption_cipher_mode: RefCell::new(None), + encryption_cipher_mode: Cell::new(None), }); self.n_connections .fetch_add(1, std::sync::atomic::Ordering::Relaxed); @@ -898,7 +898,7 @@ pub struct Connection { /// Generally this is only true for ParseSchema. 
is_nested_stmt: Cell, encryption_key: RefCell>, - encryption_cipher_mode: RefCell>, + encryption_cipher_mode: Cell>, } impl Drop for Connection { @@ -1999,7 +1999,7 @@ impl Connection { } pub fn get_encryption_cipher_mode(&self) -> Option { - *self.encryption_cipher_mode.borrow() + self.encryption_cipher_mode.get() } // if both key and cipher are set, set encryption context on pager @@ -2008,7 +2008,7 @@ impl Connection { let Some(key) = key_ref.as_ref() else { return; }; - let Some(cipher_mode) = *self.encryption_cipher_mode.borrow() else { + let Some(cipher_mode) = self.encryption_cipher_mode.get() else { return; }; tracing::trace!("setting encryption ctx for connection"); From 4cf111e3c22676a0ea9c9fc40a1a3aa5fdd2320f Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 26 Aug 2025 14:13:42 -0400 Subject: [PATCH 72/73] Rename Go driver to `turso` to not conflict with sqlite3, rename limbo->turso --- Cargo.lock | 14 ++-- bindings/go/Cargo.toml | 4 +- bindings/go/README.md | 32 ++++---- bindings/go/build_lib.sh | 14 ++-- bindings/go/connection.go | 75 +++++++++---------- bindings/go/embedded.go | 12 +-- bindings/go/rows.go | 20 ++--- bindings/go/rs_src/lib.rs | 16 ++-- bindings/go/rs_src/rows.rs | 34 ++++----- bindings/go/rs_src/statement.rs | 44 +++++------ bindings/go/rs_src/types.rs | 28 +++---- bindings/go/stmt.go | 24 +++--- bindings/go/{limbo_test.go => turso_test.go} | 30 ++++---- bindings/go/{limbo_unix.go => turso_unix.go} | 2 +- .../go/{limbo_windows.go => turso_windows.go} | 2 +- bindings/go/types.go | 50 ++++++------- 16 files changed, 200 insertions(+), 201 deletions(-) rename bindings/go/{limbo_test.go => turso_test.go} (96%) rename bindings/go/{limbo_unix.go => turso_unix.go} (99%) rename bindings/go/{limbo_windows.go => turso_windows.go} (98%) diff --git a/Cargo.lock b/Cargo.lock index e599b2f8b..57cf52cd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2066,13 +2066,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "limbo-go" -version = "0.1.4" 
-dependencies = [ - "turso_core", -] - [[package]] name = "limbo_completion" version = "0.1.4" @@ -3984,6 +3977,13 @@ dependencies = [ "turso_core", ] +[[package]] +name = "turso-go" +version = "0.1.4" +dependencies = [ + "turso_core", +] + [[package]] name = "turso-java" version = "0.1.4" diff --git a/bindings/go/Cargo.toml b/bindings/go/Cargo.toml index 228aead5e..8f8a55d76 100644 --- a/bindings/go/Cargo.toml +++ b/bindings/go/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "limbo-go" +name = "turso-go" version.workspace = true authors.workspace = true edition.workspace = true @@ -8,7 +8,7 @@ repository.workspace = true publish = false [lib] -name = "_limbo_go" +name = "_turso_go" crate-type = ["cdylib"] path = "rs_src/lib.rs" diff --git a/bindings/go/README.md b/bindings/go/README.md index 72672ebfe..af74b98a4 100644 --- a/bindings/go/README.md +++ b/bindings/go/README.md @@ -1,4 +1,4 @@ -# Limbo driver for Go's `database/sql` library +# Turso driver for Go's `database/sql` library **NOTE:** this is currently __heavily__ W.I.P and is not yet in a usable state. 
@@ -17,7 +17,7 @@ To build with embedded library support, follow these steps: git clone https://github.com/tursodatabase/turso # Navigate to the Go bindings directory -cd limbo/bindings/go +cd turso/bindings/go # Build the library (defaults to release build) ./build_lib.sh @@ -52,34 +52,34 @@ Build the driver with the embedded library as described above, then simply impor #### Linux | MacOS -_All commands listed are relative to the bindings/go directory in the limbo repository_ +_All commands listed are relative to the bindings/go directory in the turso repository_ ``` -cargo build --package limbo-go +cargo build --package turso-go -# Your LD_LIBRARY_PATH environment variable must include limbo's `target/debug` directory +# Your LD_LIBRARY_PATH environment variable must include turso's `target/debug` directory -export LD_LIBRARY_PATH="/path/to/limbo/target/debug:$LD_LIBRARY_PATH" +export LD_LIBRARY_PATH="/path/to/turso/target/debug:$LD_LIBRARY_PATH" ``` #### Windows ``` -cargo build --package limbo-go +cargo build --package turso-go -# You must add limbo's `target/debug` directory to your PATH +# You must add turso's `target/debug` directory to your PATH # or you could built + copy the .dll to a location in your PATH # or just the CWD of your go module -cp path\to\limbo\target\debug\lib_limbo_go.dll . +cp path\to\turso\target\debug\lib_turso_go.dll . 
go test ``` -**Temporarily** you may have to clone the limbo repository and run: +**Temporarily** you may have to clone the turso repository and run: -`go mod edit -replace github.com/tursodatabase/turso=/path/to/limbo/bindings/go` +`go mod edit -replace github.com/tursodatabase/turso=/path/to/turso/bindings/go` ```go import ( @@ -89,19 +89,19 @@ import ( ) func main() { - conn, err := sql.Open("sqlite3", ":memory:") + conn, err := sql.Open("turso", ":memory:") if err != nil { fmt.Printf("Error: %v\n", err) os.Exit(1) } - sql := "CREATE table go_limbo (foo INTEGER, bar TEXT)" + sql := "CREATE table go_turso (foo INTEGER, bar TEXT)" _ = conn.Exec(sql) - sql = "INSERT INTO go_limbo (foo, bar) values (?, ?)" + sql = "INSERT INTO go_turso (foo, bar) values (?, ?)" stmt, _ := conn.Prepare(sql) defer stmt.Close() - _ = stmt.Exec(42, "limbo") - rows, _ := conn.Query("SELECT * from go_limbo") + _ = stmt.Exec(42, "turso") + rows, _ := conn.Query("SELECT * from go_turso") defer rows.Close() for rows.Next() { var a int diff --git a/bindings/go/build_lib.sh b/bindings/go/build_lib.sh index 26bf07ab0..1b77bfa26 100755 --- a/bindings/go/build_lib.sh +++ b/bindings/go/build_lib.sh @@ -6,12 +6,12 @@ set -e # Accept build type as parameter, default to release BUILD_TYPE=${1:-release} -echo "Building Limbo Go library for current platform (build type: $BUILD_TYPE)..." +echo "Building turso Go library for current platform (build type: $BUILD_TYPE)..." 
# Determine platform-specific details case "$(uname -s)" in Darwin*) - OUTPUT_NAME="lib_limbo_go.dylib" + OUTPUT_NAME="lib_turso_go.dylib" # Map x86_64 to amd64 for Go compatibility ARCH=$(uname -m) if [ "$ARCH" == "x86_64" ]; then @@ -20,7 +20,7 @@ case "$(uname -s)" in PLATFORM="darwin_${ARCH}" ;; Linux*) - OUTPUT_NAME="lib_limbo_go.so" + OUTPUT_NAME="lib_turso_go.so" # Map x86_64 to amd64 for Go compatibility ARCH=$(uname -m) if [ "$ARCH" == "x86_64" ]; then @@ -29,7 +29,7 @@ case "$(uname -s)" in PLATFORM="linux_${ARCH}" ;; MINGW*|MSYS*|CYGWIN*) - OUTPUT_NAME="lib_limbo_go.dll" + OUTPUT_NAME="lib_turso_go.dll" if [ "$(uname -m)" == "x86_64" ]; then PLATFORM="windows_amd64" else @@ -60,11 +60,11 @@ else fi # Build the library -echo "Running cargo build ${CARGO_ARGS} --package limbo-go" -cargo build ${CARGO_ARGS} --package limbo-go +echo "Running cargo build ${CARGO_ARGS} --package turso-go" +cargo build ${CARGO_ARGS} --package turso-go # Copy to the appropriate directory echo "Copying $OUTPUT_NAME to $OUTPUT_DIR/" cp "../../target/${TARGET_DIR}/$OUTPUT_NAME" "$OUTPUT_DIR/" -echo "Library built successfully for $PLATFORM ($BUILD_TYPE build)" \ No newline at end of file +echo "Library built successfully for $PLATFORM ($BUILD_TYPE build)" diff --git a/bindings/go/connection.go b/bindings/go/connection.go index 27d7fc06a..2d0a1dc7b 100644 --- a/bindings/go/connection.go +++ b/bindings/go/connection.go @@ -1,4 +1,4 @@ -package limbo +package turso import ( "context" @@ -16,16 +16,16 @@ func init() { if err != nil { panic(err) } - sql.Register(driverName, &limboDriver{}) + sql.Register(driverName, &tursoDriver{}) } -type limboDriver struct { +type tursoDriver struct { sync.Mutex } var ( libOnce sync.Once - limboLib uintptr + tursoLib uintptr loadErr error dbOpen func(string) uintptr dbClose func(uintptr) uintptr @@ -49,32 +49,32 @@ var ( // Register all the symbols on library load func ensureLibLoaded() error { libOnce.Do(func() { - limboLib, loadErr = loadLibrary() + 
tursoLib, loadErr = loadLibrary() if loadErr != nil { return } - purego.RegisterLibFunc(&dbOpen, limboLib, FfiDbOpen) - purego.RegisterLibFunc(&dbClose, limboLib, FfiDbClose) - purego.RegisterLibFunc(&connPrepare, limboLib, FfiDbPrepare) - purego.RegisterLibFunc(&connGetError, limboLib, FfiDbGetError) - purego.RegisterLibFunc(&freeBlobFunc, limboLib, FfiFreeBlob) - purego.RegisterLibFunc(&freeStringFunc, limboLib, FfiFreeCString) - purego.RegisterLibFunc(&rowsGetColumns, limboLib, FfiRowsGetColumns) - purego.RegisterLibFunc(&rowsGetColumnName, limboLib, FfiRowsGetColumnName) - purego.RegisterLibFunc(&rowsGetValue, limboLib, FfiRowsGetValue) - purego.RegisterLibFunc(&closeRows, limboLib, FfiRowsClose) - purego.RegisterLibFunc(&rowsNext, limboLib, FfiRowsNext) - purego.RegisterLibFunc(&rowsGetError, limboLib, FfiRowsGetError) - purego.RegisterLibFunc(&stmtQuery, limboLib, FfiStmtQuery) - purego.RegisterLibFunc(&stmtExec, limboLib, FfiStmtExec) - purego.RegisterLibFunc(&stmtParamCount, limboLib, FfiStmtParameterCount) - purego.RegisterLibFunc(&stmtGetError, limboLib, FfiStmtGetError) - purego.RegisterLibFunc(&stmtClose, limboLib, FfiStmtClose) + purego.RegisterLibFunc(&dbOpen, tursoLib, FfiDbOpen) + purego.RegisterLibFunc(&dbClose, tursoLib, FfiDbClose) + purego.RegisterLibFunc(&connPrepare, tursoLib, FfiDbPrepare) + purego.RegisterLibFunc(&connGetError, tursoLib, FfiDbGetError) + purego.RegisterLibFunc(&freeBlobFunc, tursoLib, FfiFreeBlob) + purego.RegisterLibFunc(&freeStringFunc, tursoLib, FfiFreeCString) + purego.RegisterLibFunc(&rowsGetColumns, tursoLib, FfiRowsGetColumns) + purego.RegisterLibFunc(&rowsGetColumnName, tursoLib, FfiRowsGetColumnName) + purego.RegisterLibFunc(&rowsGetValue, tursoLib, FfiRowsGetValue) + purego.RegisterLibFunc(&closeRows, tursoLib, FfiRowsClose) + purego.RegisterLibFunc(&rowsNext, tursoLib, FfiRowsNext) + purego.RegisterLibFunc(&rowsGetError, tursoLib, FfiRowsGetError) + purego.RegisterLibFunc(&stmtQuery, tursoLib, FfiStmtQuery) + 
purego.RegisterLibFunc(&stmtExec, tursoLib, FfiStmtExec) + purego.RegisterLibFunc(&stmtParamCount, tursoLib, FfiStmtParameterCount) + purego.RegisterLibFunc(&stmtGetError, tursoLib, FfiStmtGetError) + purego.RegisterLibFunc(&stmtClose, tursoLib, FfiStmtClose) }) return loadErr } -func (d *limboDriver) Open(name string) (driver.Conn, error) { +func (d *tursoDriver) Open(name string) (driver.Conn, error) { d.Lock() conn, err := openConn(name) d.Unlock() @@ -84,23 +84,23 @@ func (d *limboDriver) Open(name string) (driver.Conn, error) { return conn, nil } -type limboConn struct { +type tursoConn struct { sync.Mutex ctx uintptr } -func openConn(dsn string) (*limboConn, error) { +func openConn(dsn string) (*tursoConn, error) { ctx := dbOpen(dsn) if ctx == 0 { return nil, fmt.Errorf("failed to open database for dsn=%q", dsn) } - return &limboConn{ + return &tursoConn{ sync.Mutex{}, ctx, }, loadErr } -func (c *limboConn) Close() error { +func (c *tursoConn) Close() error { if c.ctx == 0 { return nil } @@ -111,7 +111,7 @@ func (c *limboConn) Close() error { return nil } -func (c *limboConn) getError() error { +func (c *tursoConn) getError() error { if c.ctx == 0 { return errors.New("connection closed") } @@ -124,7 +124,7 @@ func (c *limboConn) getError() error { return errors.New(cpy) } -func (c *limboConn) Prepare(query string) (driver.Stmt, error) { +func (c *tursoConn) Prepare(query string) (driver.Stmt, error) { if c.ctx == 0 { return nil, errors.New("connection closed") } @@ -137,13 +137,13 @@ func (c *limboConn) Prepare(query string) (driver.Stmt, error) { return newStmt(stmtPtr, query), nil } -// limboTx implements driver.Tx -type limboTx struct { - conn *limboConn +// tursoTx implements driver.Tx +type tursoTx struct { + conn *tursoConn } // Begin starts a new transaction with default isolation level -func (c *limboConn) Begin() (driver.Tx, error) { +func (c *tursoConn) Begin() (driver.Tx, error) { c.Lock() defer c.Unlock() @@ -165,12 +165,12 @@ func (c *limboConn) 
Begin() (driver.Tx, error) { return nil, err } - return &limboTx{conn: c}, nil + return &tursoTx{conn: c}, nil } // BeginTx starts a transaction with the specified options. // Currently only supports default isolation level and non-read-only transactions. -func (c *limboConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { +func (c *tursoConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { // Skip handling non-default isolation levels and read-only mode // for now, letting database/sql package handle these cases if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) || opts.ReadOnly { @@ -187,7 +187,7 @@ func (c *limboConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver. } // Commit commits the transaction -func (tx *limboTx) Commit() error { +func (tx *tursoTx) Commit() error { tx.conn.Lock() defer tx.conn.Unlock() @@ -208,8 +208,7 @@ func (tx *limboTx) Commit() error { } // Rollback aborts the transaction. -// Note: This operation is not currently fully supported by Limbo and will return an error. -func (tx *limboTx) Rollback() error { +func (tx *tursoTx) Rollback() error { tx.conn.Lock() defer tx.conn.Unlock() diff --git a/bindings/go/embedded.go b/bindings/go/embedded.go index 9f04f2d79..2a44795d6 100644 --- a/bindings/go/embedded.go +++ b/bindings/go/embedded.go @@ -1,4 +1,4 @@ -// Go bindings for the Limbo database. +// Go bindings for the turso database. // // This file implements library embedding and extraction at runtime, a pattern // also used in several other Go projects that need to distribute native binaries: @@ -21,7 +21,7 @@ // The embedded library is extracted to a user-specific temporary directory and // loaded dynamically. If extraction fails, the code falls back to the traditional // method of searching system paths. 
-package limbo +package turso import ( "embed" @@ -52,11 +52,11 @@ func extractEmbeddedLibrary() (string, error) { switch runtime.GOOS { case "darwin": - libName = "lib_limbo_go.dylib" + libName = "lib_turso_go.dylib" case "linux": - libName = "lib_limbo_go.so" + libName = "lib_turso_go.so" case "windows": - libName = "lib_limbo_go.dll" + libName = "lib_turso_go.dll" default: extractErr = fmt.Errorf("unsupported operating system: %s", runtime.GOOS) return @@ -80,7 +80,7 @@ func extractEmbeddedLibrary() (string, error) { platformDir = fmt.Sprintf("%s_%s", runtime.GOOS, archSuffix) // Create a unique temporary directory for the current user - tempDir := filepath.Join(os.TempDir(), fmt.Sprintf("limbo-go-%d", os.Getuid())) + tempDir := filepath.Join(os.TempDir(), fmt.Sprintf("turso-go-%d", os.Getuid())) if err := os.MkdirAll(tempDir, 0755); err != nil { extractErr = fmt.Errorf("failed to create temp directory: %w", err) return diff --git a/bindings/go/rows.go b/bindings/go/rows.go index 1d14e0d0c..c82bd2e65 100644 --- a/bindings/go/rows.go +++ b/bindings/go/rows.go @@ -1,4 +1,4 @@ -package limbo +package turso import ( "database/sql/driver" @@ -8,7 +8,7 @@ import ( "sync" ) -type limboRows struct { +type tursoRows struct { mu sync.Mutex ctx uintptr columns []string @@ -16,8 +16,8 @@ type limboRows struct { closed bool } -func newRows(ctx uintptr) *limboRows { - return &limboRows{ +func newRows(ctx uintptr) *tursoRows { + return &tursoRows{ mu: sync.Mutex{}, ctx: ctx, columns: nil, @@ -26,14 +26,14 @@ func newRows(ctx uintptr) *limboRows { } } -func (r *limboRows) isClosed() bool { +func (r *tursoRows) isClosed() bool { if r.ctx == 0 || r.closed { return true } return false } -func (r *limboRows) Columns() []string { +func (r *tursoRows) Columns() []string { if r.isClosed() { return nil } @@ -54,7 +54,7 @@ func (r *limboRows) Columns() []string { return r.columns } -func (r *limboRows) Close() error { +func (r *tursoRows) Close() error { r.err = 
errors.New(RowsClosedErr) if r.isClosed() { return r.err @@ -67,7 +67,7 @@ func (r *limboRows) Close() error { return nil } -func (r *limboRows) Err() error { +func (r *tursoRows) Err() error { if r.err == nil { r.mu.Lock() defer r.mu.Unlock() @@ -76,7 +76,7 @@ func (r *limboRows) Err() error { return r.err } -func (r *limboRows) Next(dest []driver.Value) error { +func (r *tursoRows) Next(dest []driver.Value) error { r.mu.Lock() defer r.mu.Unlock() if r.isClosed() { @@ -106,7 +106,7 @@ func (r *limboRows) Next(dest []driver.Value) error { } // mutex will already be locked. this is always called after FFI -func (r *limboRows) getError() error { +func (r *tursoRows) getError() error { if r.isClosed() { return r.err } diff --git a/bindings/go/rs_src/lib.rs b/bindings/go/rs_src/lib.rs index 475ad062f..26a2a4dfd 100644 --- a/bindings/go/rs_src/lib.rs +++ b/bindings/go/rs_src/lib.rs @@ -23,19 +23,19 @@ pub unsafe extern "C" fn db_open(path: *const c_char) -> *mut c_void { let Ok((io, conn)) = Connection::from_uri(path, true, false, false) else { panic!("Failed to open connection with path: {path}"); }; - LimboConn::new(conn, io).to_ptr() + TursoConn::new(conn, io).to_ptr() } #[allow(dead_code)] -struct LimboConn { +struct TursoConn { conn: Arc, io: Arc, err: Option, } -impl LimboConn { +impl TursoConn { fn new(conn: Arc, io: Arc) -> Self { - LimboConn { + TursoConn { conn, io, err: None, @@ -47,11 +47,11 @@ impl LimboConn { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'static mut LimboConn { + fn from_ptr(ptr: *mut c_void) -> &'static mut TursoConn { if ptr.is_null() { panic!("Null pointer"); } - unsafe { &mut *(ptr as *mut LimboConn) } + unsafe { &mut *(ptr as *mut TursoConn) } } fn get_error(&mut self) -> *const c_char { @@ -73,7 +73,7 @@ pub extern "C" fn db_get_error(ctx: *mut c_void) -> *const c_char { if ctx.is_null() { return std::ptr::null(); } - let conn = LimboConn::from_ptr(ctx); + let conn = TursoConn::from_ptr(ctx); 
conn.get_error() } @@ -83,6 +83,6 @@ pub extern "C" fn db_get_error(ctx: *mut c_void) -> *const c_char { #[no_mangle] pub unsafe extern "C" fn db_close(db: *mut c_void) { if !db.is_null() { - let _ = unsafe { Box::from_raw(db as *mut LimboConn) }; + let _ = unsafe { Box::from_raw(db as *mut TursoConn) }; } } diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs index d3d64ed3c..8a05440a5 100644 --- a/bindings/go/rs_src/rows.rs +++ b/bindings/go/rs_src/rows.rs @@ -1,19 +1,19 @@ use crate::{ - types::{LimboValue, ResultCode}, - LimboConn, + types::{ResultCode, TursoValue}, + TursoConn, }; use std::ffi::{c_char, c_void}; use turso_core::{LimboError, Statement, StepResult, Value}; -pub struct LimboRows<'conn> { +pub struct TursoRows<'conn> { stmt: Box, - _conn: &'conn mut LimboConn, + _conn: &'conn mut TursoConn, err: Option, } -impl<'conn> LimboRows<'conn> { - pub fn new(stmt: Statement, conn: &'conn mut LimboConn) -> Self { - LimboRows { +impl<'conn> TursoRows<'conn> { + pub fn new(stmt: Statement, conn: &'conn mut TursoConn) -> Self { + TursoRows { stmt: Box::new(stmt), _conn: conn, err: None, @@ -25,11 +25,11 @@ impl<'conn> LimboRows<'conn> { Box::into_raw(Box::new(self)) as *mut c_void } - pub fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboRows<'conn> { + pub fn from_ptr(ptr: *mut c_void) -> &'conn mut TursoRows<'conn> { if ptr.is_null() { panic!("Null pointer"); } - unsafe { &mut *(ptr as *mut LimboRows) } + unsafe { &mut *(ptr as *mut TursoRows) } } fn get_error(&mut self) -> *const c_char { @@ -49,7 +49,7 @@ pub extern "C" fn rows_next(ctx: *mut c_void) -> ResultCode { if ctx.is_null() { return ResultCode::Error; } - let ctx = LimboRows::from_ptr(ctx); + let ctx = TursoRows::from_ptr(ctx); match ctx.stmt.step() { Ok(StepResult::Row) => ResultCode::Row, @@ -76,11 +76,11 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v if ctx.is_null() { return std::ptr::null(); } - let ctx = LimboRows::from_ptr(ctx); + let ctx 
= TursoRows::from_ptr(ctx); if let Some(row) = ctx.stmt.row() { if let Ok(value) = row.get::<&Value>(col_idx) { - return LimboValue::from_db_value(value).to_ptr(); + return TursoValue::from_db_value(value).to_ptr(); } } std::ptr::null() @@ -101,7 +101,7 @@ pub extern "C" fn rows_get_columns(rows_ptr: *mut c_void) -> i32 { if rows_ptr.is_null() { return -1; } - let rows = LimboRows::from_ptr(rows_ptr); + let rows = TursoRows::from_ptr(rows_ptr); rows.stmt.num_columns() as i32 } @@ -113,7 +113,7 @@ pub extern "C" fn rows_get_column_name(rows_ptr: *mut c_void, idx: i32) -> *cons if rows_ptr.is_null() { return std::ptr::null_mut(); } - let rows = LimboRows::from_ptr(rows_ptr); + let rows = TursoRows::from_ptr(rows_ptr); if idx < 0 || idx as usize >= rows.stmt.num_columns() { return std::ptr::null_mut(); } @@ -127,18 +127,18 @@ pub extern "C" fn rows_get_error(ctx: *mut c_void) -> *const c_char { if ctx.is_null() { return std::ptr::null(); } - let ctx = LimboRows::from_ptr(ctx); + let ctx = TursoRows::from_ptr(ctx); ctx.get_error() } #[no_mangle] pub extern "C" fn rows_close(ctx: *mut c_void) { if !ctx.is_null() { - let rows = LimboRows::from_ptr(ctx); + let rows = TursoRows::from_ptr(ctx); rows.stmt.reset(); rows.err = None; } unsafe { - let _ = Box::from_raw(ctx.cast::()); + let _ = Box::from_raw(ctx.cast::()); } } diff --git a/bindings/go/rs_src/statement.rs b/bindings/go/rs_src/statement.rs index 4dc115aec..65859161d 100644 --- a/bindings/go/rs_src/statement.rs +++ b/bindings/go/rs_src/statement.rs @@ -1,6 +1,6 @@ -use crate::rows::LimboRows; -use crate::types::{AllocPool, LimboValue, ResultCode}; -use crate::LimboConn; +use crate::rows::TursoRows; +use crate::types::{AllocPool, ResultCode, TursoValue}; +use crate::TursoConn; use std::ffi::{c_char, c_void}; use std::num::NonZero; use turso_core::{LimboError, Statement, StepResult}; @@ -12,10 +12,10 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v } let query_str = unsafe { 
std::ffi::CStr::from_ptr(query) }.to_str().unwrap(); - let db = LimboConn::from_ptr(ctx); + let db = TursoConn::from_ptr(ctx); let stmt = db.conn.prepare(query_str); match stmt { - Ok(stmt) => LimboStatement::new(Some(stmt), db).to_ptr(), + Ok(stmt) => TursoStatement::new(Some(stmt), db).to_ptr(), Err(err) => { db.err = Some(err); std::ptr::null_mut() @@ -26,14 +26,14 @@ pub extern "C" fn db_prepare(ctx: *mut c_void, query: *const c_char) -> *mut c_v #[no_mangle] pub extern "C" fn stmt_execute( ctx: *mut c_void, - args_ptr: *mut LimboValue, + args_ptr: *mut TursoValue, arg_count: usize, changes: *mut i64, ) -> ResultCode { if ctx.is_null() { return ResultCode::Error; } - let stmt = LimboStatement::from_ptr(ctx); + let stmt = TursoStatement::from_ptr(ctx); let args = if !args_ptr.is_null() && arg_count > 0 { unsafe { std::slice::from_raw_parts(args_ptr, arg_count) } @@ -88,7 +88,7 @@ pub extern "C" fn stmt_parameter_count(ctx: *mut c_void) -> i32 { if ctx.is_null() { return -1; } - let stmt = LimboStatement::from_ptr(ctx); + let stmt = TursoStatement::from_ptr(ctx); let Some(statement) = stmt.statement.as_ref() else { stmt.err = Some(LimboError::InternalError("Statement is closed".to_string())); return -1; @@ -99,13 +99,13 @@ pub extern "C" fn stmt_parameter_count(ctx: *mut c_void) -> i32 { #[no_mangle] pub extern "C" fn stmt_query( ctx: *mut c_void, - args_ptr: *mut LimboValue, + args_ptr: *mut TursoValue, args_count: usize, ) -> *mut c_void { if ctx.is_null() { return std::ptr::null_mut(); } - let stmt = LimboStatement::from_ptr(ctx); + let stmt = TursoStatement::from_ptr(ctx); let args = if !args_ptr.is_null() && args_count > 0 { unsafe { std::slice::from_raw_parts(args_ptr, args_count) } } else { @@ -119,21 +119,21 @@ pub extern "C" fn stmt_query( let val = arg.to_value(&mut pool); statement.bind_at(NonZero::new(i + 1).unwrap(), val); } - // ownership of the statement is transferred to the LimboRows object. 
- LimboRows::new(statement, stmt.conn).to_ptr() + // ownership of the statement is transferred to the TursoRows object. + TursoRows::new(statement, stmt.conn).to_ptr() } -pub struct LimboStatement<'conn> { - /// If 'query' is ran on the statement, ownership is transferred to the LimboRows object +pub struct TursoStatement<'conn> { + /// If 'query' is ran on the statement, ownership is transferred to the TursoRows object pub statement: Option, - pub conn: &'conn mut LimboConn, + pub conn: &'conn mut TursoConn, pub err: Option, } #[no_mangle] pub extern "C" fn stmt_close(ctx: *mut c_void) -> ResultCode { if !ctx.is_null() { - let stmt = unsafe { Box::from_raw(ctx as *mut LimboStatement) }; + let stmt = unsafe { Box::from_raw(ctx as *mut TursoStatement) }; drop(stmt); return ResultCode::Ok; } @@ -145,13 +145,13 @@ pub extern "C" fn stmt_get_error(ctx: *mut c_void) -> *const c_char { if ctx.is_null() { return std::ptr::null(); } - let stmt = LimboStatement::from_ptr(ctx); + let stmt = TursoStatement::from_ptr(ctx); stmt.get_error() } -impl<'conn> LimboStatement<'conn> { - pub fn new(statement: Option, conn: &'conn mut LimboConn) -> Self { - LimboStatement { +impl<'conn> TursoStatement<'conn> { + pub fn new(statement: Option, conn: &'conn mut TursoConn) -> Self { + TursoStatement { statement, conn, err: None, @@ -163,11 +163,11 @@ impl<'conn> LimboStatement<'conn> { Box::into_raw(Box::new(self)) as *mut c_void } - fn from_ptr(ptr: *mut c_void) -> &'conn mut LimboStatement<'conn> { + fn from_ptr(ptr: *mut c_void) -> &'conn mut TursoStatement<'conn> { if ptr.is_null() { panic!("Null pointer"); } - unsafe { &mut *(ptr as *mut LimboStatement) } + unsafe { &mut *(ptr as *mut TursoStatement) } } fn get_error(&mut self) -> *const c_char { diff --git a/bindings/go/rs_src/types.rs b/bindings/go/rs_src/types.rs index 683cfde3f..9ec06b3bf 100644 --- a/bindings/go/rs_src/types.rs +++ b/bindings/go/rs_src/types.rs @@ -34,33 +34,33 @@ pub enum ValueType { } #[repr(C)] -pub struct 
LimboValue { +pub struct TursoValue { value_type: ValueType, value: ValueUnion, } -impl Debug for LimboValue { +impl Debug for TursoValue { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.value_type { ValueType::Integer => { let i = self.value.to_int(); - f.debug_struct("LimboValue").field("value", &i).finish() + f.debug_struct("TursoValue").field("value", &i).finish() } ValueType::Real => { let r = self.value.to_real(); - f.debug_struct("LimboValue").field("value", &r).finish() + f.debug_struct("TursoValue").field("value", &r).finish() } ValueType::Text => { let t = self.value.to_str(); - f.debug_struct("LimboValue").field("value", &t).finish() + f.debug_struct("TursoValue").field("value", &t).finish() } ValueType::Blob => { let blob = self.value.to_bytes(); - f.debug_struct("LimboValue") + f.debug_struct("TursoValue") .field("value", &blob.to_vec()) .finish() } ValueType::Null => f - .debug_struct("LimboValue") + .debug_struct("TursoValue") .field("value", &"NULL") .finish(), } @@ -164,9 +164,9 @@ impl ValueUnion { } } -impl LimboValue { +impl TursoValue { fn new(value_type: ValueType, value: ValueUnion) -> Self { - LimboValue { value_type, value } + TursoValue { value_type, value } } #[allow(clippy::wrong_self_convention)] @@ -177,18 +177,18 @@ impl LimboValue { pub fn from_db_value(value: &turso_core::Value) -> Self { match value { turso_core::Value::Integer(i) => { - LimboValue::new(ValueType::Integer, ValueUnion::from_int(*i)) + TursoValue::new(ValueType::Integer, ValueUnion::from_int(*i)) } turso_core::Value::Float(r) => { - LimboValue::new(ValueType::Real, ValueUnion::from_real(*r)) + TursoValue::new(ValueType::Real, ValueUnion::from_real(*r)) } turso_core::Value::Text(s) => { - LimboValue::new(ValueType::Text, ValueUnion::from_str(s.as_str())) + TursoValue::new(ValueType::Text, ValueUnion::from_str(s.as_str())) } turso_core::Value::Blob(b) => { - LimboValue::new(ValueType::Blob, ValueUnion::from_bytes(b.as_slice())) + 
TursoValue::new(ValueType::Blob, ValueUnion::from_bytes(b.as_slice())) } - turso_core::Value::Null => LimboValue::new(ValueType::Null, ValueUnion::from_null()), + turso_core::Value::Null => TursoValue::new(ValueType::Null, ValueUnion::from_null()), } } diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 9e045175e..c12ae9d71 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -1,4 +1,4 @@ -package limbo +package turso import ( "context" @@ -9,22 +9,22 @@ import ( "unsafe" ) -type limboStmt struct { +type tursoStmt struct { mu sync.Mutex ctx uintptr sql string err error } -func newStmt(ctx uintptr, sql string) *limboStmt { - return &limboStmt{ +func newStmt(ctx uintptr, sql string) *tursoStmt { + return &tursoStmt{ ctx: uintptr(ctx), sql: sql, err: nil, } } -func (ls *limboStmt) NumInput() int { +func (ls *tursoStmt) NumInput() int { ls.mu.Lock() defer ls.mu.Unlock() res := int(stmtParamCount(ls.ctx)) @@ -35,7 +35,7 @@ func (ls *limboStmt) NumInput() int { return res } -func (ls *limboStmt) Close() error { +func (ls *tursoStmt) Close() error { ls.mu.Lock() defer ls.mu.Unlock() if ls.ctx == 0 { @@ -49,7 +49,7 @@ func (ls *limboStmt) Close() error { return nil } -func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { +func (ls *tursoStmt) Exec(args []driver.Value) (driver.Result, error) { argArray, cleanup, err := buildArgs(args) defer cleanup() if err != nil { @@ -80,7 +80,7 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { } } -func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { +func (ls *tursoStmt) Query(args []driver.Value) (driver.Rows, error) { queryArgs, cleanup, err := buildArgs(args) defer cleanup() if err != nil { @@ -99,7 +99,7 @@ func (ls *limboStmt) Query(args []driver.Value) (driver.Rows, error) { return newRows(rowsPtr), nil } -func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (ls *tursoStmt) 
ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { stripped := namedValueToValue(args) argArray, cleanup, err := getArgsPtr(stripped) defer cleanup() @@ -129,7 +129,7 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive } } -func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +func (ls *tursoStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { queryArgs, allocs, err := buildNamedArgs(args) defer allocs() if err != nil { @@ -154,7 +154,7 @@ func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) } } -func (ls *limboStmt) Err() error { +func (ls *tursoStmt) Err() error { if ls.err == nil { ls.mu.Lock() defer ls.mu.Unlock() @@ -164,7 +164,7 @@ func (ls *limboStmt) Err() error { } // mutex should always be locked when calling - always called after FFI -func (ls *limboStmt) getError() error { +func (ls *tursoStmt) getError() error { err := stmtGetError(ls.ctx) if err == 0 { return nil diff --git a/bindings/go/limbo_test.go b/bindings/go/turso_test.go similarity index 96% rename from bindings/go/limbo_test.go rename to bindings/go/turso_test.go index 8fe36ae17..ff2bc90a4 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/turso_test.go @@ -1,4 +1,4 @@ -package limbo_test +package turso_test import ( "database/sql" @@ -17,7 +17,7 @@ var ( ) func TestMain(m *testing.M) { - conn, connErr = sql.Open("sqlite3", ":memory:") + conn, connErr = sql.Open("turso", ":memory:") if connErr != nil { panic(connErr) } @@ -146,7 +146,7 @@ func TestFunctions(t *testing.T) { } func TestDuplicateConnection(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -177,7 +177,7 @@ func TestDuplicateConnection(t *testing.T) { } func TestDuplicateConnection2(t *testing.T) { - 
newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -209,7 +209,7 @@ func TestDuplicateConnection2(t *testing.T) { } func TestConnectionError(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -228,7 +228,7 @@ func TestConnectionError(t *testing.T) { } func TestStatementError(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -250,7 +250,7 @@ func TestStatementError(t *testing.T) { } func TestDriverRowsErrorMessages(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("failed to open database: %v", err) } @@ -285,7 +285,7 @@ func TestDriverRowsErrorMessages(t *testing.T) { func TestTransaction(t *testing.T) { // Open database connection - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening database: %v", err) } @@ -359,7 +359,7 @@ func TestTransaction(t *testing.T) { } func TestVectorOperations(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening connection: %v", err) } @@ -397,7 +397,7 @@ func TestVectorOperations(t *testing.T) { } func TestSQLFeatures(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening connection: %v", err) } @@ -501,7 +501,7 @@ func TestSQLFeatures(t *testing.T) { } func TestDateTimeFunctions(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening 
connection: %v", err) } @@ -536,7 +536,7 @@ func TestDateTimeFunctions(t *testing.T) { } func TestMathFunctions(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening connection: %v", err) } @@ -572,7 +572,7 @@ func TestMathFunctions(t *testing.T) { } func TestJSONFunctions(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + db, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening connection: %v", err) } @@ -610,7 +610,7 @@ func TestJSONFunctions(t *testing.T) { } func TestParameterOrdering(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } @@ -685,7 +685,7 @@ func TestParameterOrdering(t *testing.T) { } func TestIndex(t *testing.T) { - newConn, err := sql.Open("sqlite3", ":memory:") + newConn, err := sql.Open("turso", ":memory:") if err != nil { t.Fatalf("Error opening new connection: %v", err) } diff --git a/bindings/go/limbo_unix.go b/bindings/go/turso_unix.go similarity index 99% rename from bindings/go/limbo_unix.go rename to bindings/go/turso_unix.go index 1dd51f42e..3e61f278f 100644 --- a/bindings/go/limbo_unix.go +++ b/bindings/go/turso_unix.go @@ -1,6 +1,6 @@ //go:build linux || darwin -package limbo +package turso import ( "fmt" diff --git a/bindings/go/limbo_windows.go b/bindings/go/turso_windows.go similarity index 98% rename from bindings/go/limbo_windows.go rename to bindings/go/turso_windows.go index 2fddfd9a4..3926bedc9 100644 --- a/bindings/go/limbo_windows.go +++ b/bindings/go/turso_windows.go @@ -1,6 +1,6 @@ //go:build windows -package limbo +package turso import ( "fmt" diff --git a/bindings/go/types.go b/bindings/go/types.go index f35899828..1c9f00d62 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -1,4 +1,4 @@ -package limbo +package turso import ( "database/sql/driver" @@ 
-66,8 +66,8 @@ func (rc ResultCode) String() string { } const ( - driverName = "sqlite3" - libName = "lib_limbo_go" + driverName = "turso" + libName = "lib_turso_go" RowsClosedErr = "sql: Rows closed" FfiDbOpen = "db_open" FfiDbClose = "db_close" @@ -98,7 +98,7 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value { return out } -func buildNamedArgs(named []driver.NamedValue) ([]limboValue, func(), error) { +func buildNamedArgs(named []driver.NamedValue) ([]tursoValue, func(), error) { args := namedValueToValue(named) return buildArgs(args) } @@ -131,7 +131,7 @@ func (vt valueType) String() string { } // struct to pass Go values over FFI -type limboValue struct { +type tursoValue struct { Type valueType _ [4]byte Value [8]byte @@ -143,12 +143,12 @@ type Blob struct { Len int64 } -// convert a limboValue to a native Go value +// convert a tursoValue to a native Go value func toGoValue(valPtr uintptr) interface{} { if valPtr == 0 { return nil } - val := (*limboValue)(unsafe.Pointer(valPtr)) + val := (*tursoValue)(unsafe.Pointer(valPtr)) switch val.Type { case intVal: return *(*int64)(unsafe.Pointer(&val.Value)) @@ -228,50 +228,50 @@ func freeCString(cstrPtr uintptr) { freeStringFunc(cstrPtr) } -// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI +// convert a Go slice of driver.Value to a slice of tursoValue that can be sent over FFI // for Blob types, we have to pin them so they are not garbage collected before they can be copied // into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call -func buildArgs(args []driver.Value) ([]limboValue, func(), error) { +func buildArgs(args []driver.Value) ([]tursoValue, func(), error) { pinner := new(runtime.Pinner) - argSlice := make([]limboValue, len(args)) + argSlice := make([]tursoValue, len(args)) for i, v := range args { - limboVal := limboValue{} + tursoVal := tursoValue{} switch val := v.(type) { case nil: - limboVal.Type 
= nullVal + tursoVal.Type = nullVal case int64: - limboVal.Type = intVal - limboVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) + tursoVal.Type = intVal + tursoVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) case float64: - limboVal.Type = realVal - limboVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) + tursoVal.Type = realVal + tursoVal.Value = *(*[8]byte)(unsafe.Pointer(&val)) case bool: - limboVal.Type = intVal + tursoVal.Type = intVal boolAsInt := int64(0) if val { boolAsInt = 1 } - limboVal.Value = *(*[8]byte)(unsafe.Pointer(&boolAsInt)) + tursoVal.Value = *(*[8]byte)(unsafe.Pointer(&boolAsInt)) case string: - limboVal.Type = textVal + tursoVal.Type = textVal cstr := CString(val) pinner.Pin(cstr) - *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr)) + *(*uintptr)(unsafe.Pointer(&tursoVal.Value)) = uintptr(unsafe.Pointer(cstr)) case []byte: - limboVal.Type = blobVal + tursoVal.Type = blobVal blob := makeBlob(val) pinner.Pin(blob) - *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob)) + *(*uintptr)(unsafe.Pointer(&tursoVal.Value)) = uintptr(unsafe.Pointer(blob)) case time.Time: - limboVal.Type = textVal + tursoVal.Type = textVal timeStr := val.Format(time.RFC3339) cstr := CString(timeStr) pinner.Pin(cstr) - *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr)) + *(*uintptr)(unsafe.Pointer(&tursoVal.Value)) = uintptr(unsafe.Pointer(cstr)) default: return nil, pinner.Unpin, fmt.Errorf("unsupported type: %T", v) } - argSlice[i] = limboVal + argSlice[i] = tursoVal } return argSlice, pinner.Unpin, nil } From 2614a42294028f6896aaf257ad2913d74bd80ce1 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Tue, 26 Aug 2025 14:18:57 -0400 Subject: [PATCH 73/73] Update package name in go CI --- .github/workflows/go.yml | 3 +-- README.md | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index acf74e39a..be3b1a84b 100644 --- 
a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -34,10 +34,9 @@ jobs: go-version: "1.23" - name: build Go bindings library - run: cargo build --package limbo-go + run: cargo build --package turso-go - name: run Go tests env: LD_LIBRARY_PATH: ${{ github.workspace }}/target/debug:$LD_LIBRARY_PATH run: go test - diff --git a/README.md b/README.md index a98999ee5..ceaca572c 100644 --- a/README.md +++ b/README.md @@ -174,8 +174,8 @@ print(res.fetchone()) 1. Clone the repository 2. Build the library and set your LD_LIBRARY_PATH to include turso's target directory ```console -cargo build --package limbo-go -export LD_LIBRARY_PATH=/path/to/limbo/target/debug:$LD_LIBRARY_PATH +cargo build --package turso-go +export LD_LIBRARY_PATH=/path/to/turso/target/debug:$LD_LIBRARY_PATH ``` 3. Use the driver @@ -191,7 +191,7 @@ import ( _ "github.com/tursodatabase/turso" ) -conn, _ = sql.Open("sqlite3", "sqlite.db") +conn, _ = sql.Open("turso", "sqlite.db") defer conn.Close() stmt, _ := conn.Prepare("select * from users")