diff --git a/Cargo.lock b/Cargo.lock
index 9e9739689..a331618e6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2137,6 +2137,14 @@ dependencies = [
  "turso_ext",
 ]
 
+[[package]]
+name = "limbo_fuzzy"
+version = "0.2.0-pre.7"
+dependencies = [
+ "mimalloc",
+ "turso_ext",
+]
+
 [[package]]
 name = "limbo_ipaddr"
 version = "0.2.0-pre.7"
diff --git a/Cargo.toml b/Cargo.toml
index 6f3c64d9d..7ffeebf3e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,6 +19,7 @@ members = [
     "extensions/percentile",
     "extensions/regexp",
     "extensions/tests",
+    "extensions/fuzzy",
     "macros",
     "simulator",
     "sqlite3",
@@ -61,6 +62,7 @@ limbo_percentile = { path = "extensions/percentile", version = "0.2.0-pre.7" }
 limbo_regexp = { path = "extensions/regexp", version = "0.2.0-pre.7" }
 turso_sqlite3_parser = { path = "vendored/sqlite3-parser", version = "0.2.0-pre.7" }
 limbo_uuid = { path = "extensions/uuid", version = "0.2.0-pre.7" }
+limbo_fuzzy = { path = "extensions/fuzzy", version = "0.2.0-pre.7" }
 turso_parser = { path = "parser", version = "0.2.0-pre.7" }
 sql_generation = { path = "sql_generation" }
 strum = { version = "0.26", features = ["derive"] }
diff --git a/extensions/fuzzy/Cargo.toml b/extensions/fuzzy/Cargo.toml
new file mode 100644
index 000000000..cf6036cfd
--- /dev/null
+++ b/extensions/fuzzy/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "limbo_fuzzy"
+version.workspace = true
+authors.workspace = true
+edition.workspace = true
+license.workspace = true
+repository.workspace = true
+description = "Limbo fuzzy string extension"
+
+[lib]
+crate-type = ["cdylib", "lib"]
+
+[features]
+static = ["turso_ext/static"]
+
+[dependencies]
+turso_ext = { workspace = true, features = ["static"] }
+
+[target.'cfg(not(target_family = "wasm"))'.dependencies]
+mimalloc = { version = "0.1", default-features = false }
diff --git a/extensions/fuzzy/build.rs b/extensions/fuzzy/build.rs
new file mode 100644
index 000000000..4a3d51d14
--- /dev/null
+++ b/extensions/fuzzy/build.rs
@@ -0,0 +1,8 @@
+// Build script: emits the extra native library the extension links against on Windows.
+fn main() {
+    // Build scripts are compiled for the host, so `cfg!(target_os = ...)` would test
+    // the host OS; Cargo exposes the real target through CARGO_CFG_TARGET_OS.
+    if std::env::var("CARGO_CFG_TARGET_OS").as_deref() == Ok("windows") {
+        println!("cargo:rustc-link-lib=advapi32");
+    }
+}
diff --git a/extensions/fuzzy/src/common.rs b/extensions/fuzzy/src/common.rs
new file mode 100644
index 000000000..4b0c12fd1
--- /dev/null
+++ b/extensions/fuzzy/src/common.rs
@@ -0,0 +1,47 @@
+// Character classes and per-ASCII-code class tables used by the edit-distance costs.
+
+/// Class 0: silent characters; inserting or deleting one costs only 1.
+pub const CCLASS_SILENT: u8 = 0;
+/// Class 1: vowels.
+pub const CCLASS_VOWEL: u8 = 1;
+/// First class of the `CCLASS_B..=CCLASS_Y` range tested by `substitute_cost`.
+pub const CCLASS_B: u8 = 2;
+/// Last class of that range; only `INIT_CLASS` assigns it (to 'Y' and 'y').
+pub const CCLASS_Y: u8 = 9;
+// The remaining class ids are kept for the upcoming phonetic functions:
+//pub const CCLASS_L: u8 = 6;
+//pub const CCLASS_R: u8 = 7;
+//pub const CCLASS_M: u8 = 8;
+//pub const CCLASS_DIGIT: u8 = 10;
+//pub const CCLASS_SPACE: u8 = 11;
+//pub const CCLASS_OTHER: u8 = 12;
+
+/// Class of every 7-bit ASCII code when the character is NOT the first one of
+/// the word; `character_class` reads this table when `c_prev != 0`.
+// NOTE(review): class 11 sits at index 0x1F rather than 0x20 (' '), and class 0
+// at 0x26 ('&') rather than 0x27 ('\''); confirm both against upstream spellfix.c.
+pub const MID_CLASS: [u8; 128] = [
+    12, 12, 12, 12, 12, 12, 12, 12, 12, 11, 12, 12, 11, 11, 12, 12, // 0x00-0x0F
+    12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 11, // 0x10-0x1F
+    12, 12, 12, 12, 12, 12, 0, 12, 12, 12, 12, 12, 12, 12, 12, 12, // 0x20-0x2F
+    10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 12, 12, // 0x30-0x3F
+    12, 1, 2, 3, 4, 1, 2, 3, 0, 1, 3, 3, 6, 8, 8, 1, // 0x40-0x4F
+    2, 3, 7, 3, 4, 1, 2, 2, 3, 1, 3, 12, 12, 12, 12, 12, // 0x50-0x5F
+    12, 1, 2, 3, 4, 1, 2, 3, 0, 1, 3, 3, 6, 8, 8, 1, // 0x60-0x6F
+    2, 3, 7, 3, 4, 1, 2, 2, 3, 1, 3, 12, 12, 12, 12, 12, // 0x70-0x7F
+];
+
+/// Class of every 7-bit ASCII code when the character starts the word;
+/// `character_class` reads this table when `c_prev == 0`.
+// NOTE(review): class 11 sits at index 0x1F rather than 0x20 (' '); confirm
+// against upstream spellfix.c.
+pub const INIT_CLASS: [u8; 128] = [
+    12, 12, 12, 12, 12, 12, 12, 12, 12, 11, 12, 12, 11, 11, 12, 12, // 0x00-0x0F
+    12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 11, // 0x10-0x1F
+    12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, // 0x20-0x2F
+    10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 12, 12, // 0x30-0x3F
+    12, 1, 2, 3, 4, 1, 2, 3, 0, 1, 3, 3, 6, 8, 8, 1, // 0x40-0x4F
+    2, 3, 7, 3, 4, 1, 2, 2, 3, 9, 3, 12, 12, 12, 12, 12, // 0x50-0x5F
+    12, 1, 2, 3, 4, 1, 2, 3, 0, 1, 3, 3, 6, 8, 8, 1, // 0x60-0x6F
+    2, 3, 7, 3, 4, 1, 2, 2, 3, 9, 3, 12, 12, 12, 12, 12, // 0x70-0x7F
+];
diff --git a/extensions/fuzzy/src/editdist.rs b/extensions/fuzzy/src/editdist.rs
new file mode 100644
index 000000000..8821cfbda
--- /dev/null
+++ b/extensions/fuzzy/src/editdist.rs
@@ -0,0 +1,276 @@
+// Adapted from SQLite spellfix.c extension and sqlean fuzzy/editdist.c
+use crate::common::*;
+
+/// Reasons [`edit_distance`] can reject its input.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum EditDistanceError {
+    /// At least one of the two strings contains a non-ASCII character.
+    NonAsciiInput,
+}
+
+/// Outcome of [`edit_distance`]: the weighted cost, or the reason the input was rejected.
+pub type EditDistanceResult = Result<i32, EditDistanceError>;
+
+/// Class of `c`, read from `INIT_CLASS` when `c` starts the word (`c_prev == 0`) and
+/// from `MID_CLASS` otherwise. Only the low seven bits of `c` index the 128-entry tables.
+fn character_class(c_prev: u8, c: u8) -> u8 {
+    let index = usize::from(c & 0x7f);
+    if c_prev == 0 {
+        INIT_CLASS[index]
+    } else {
+        MID_CLASS[index]
+    }
+}
+
+/// Return the cost of inserting or deleting character c immediately
+/// following character c_prev. If c_prev == 0, that means c is the first
+/// character of the word.
+fn insert_or_delete_cost(c_prev: u8, c: u8, c_next: u8) -> i32 {
+    let class_c = character_class(c_prev, c);
+    if class_c == CCLASS_SILENT {
+        return 1;
+    }
+    if c_prev == c {
+        return 10; // Repeating the previous character
+    }
+    if class_c == CCLASS_VOWEL && (c_prev == b'r' || c_next == b'r') {
+        return 20; // Insert a vowel before or after 'r'
+    }
+
+    let class_c_prev = character_class(c_prev, c_prev);
+    if class_c == class_c_prev {
+        if class_c == CCLASS_VOWEL {
+            15
+        } else {
+            50
+        }
+    } else {
+        // Any other character insertion or deletion
+        100
+    }
+}
+
+/// Divisor applied to the cost of characters inserted after the last
+/// character of the pattern, so that trailing extras are penalized less.
+const FINAL_INS_COST_DIV: i32 = 4;
+
+/// Return the cost of substituting c_to in place of c_from assuming
+/// the previous character is c_prev. If c_prev == 0 then c_to is the first
+/// character of the word.
+fn substitute_cost(c_prev: u8, c_from: u8, c_to: u8) -> i32 {
+    // Identical characters, or the same letter in the other case, are free.
+    if c_from == c_to || (c_from == (c_to ^ 0x20) && c_to.is_ascii_alphabetic()) {
+        return 0;
+    }
+    let class_from = character_class(c_prev, c_from);
+    let class_to = character_class(c_prev, c_to);
+    if class_from == class_to {
+        40
+    } else if (CCLASS_B..=CCLASS_Y).contains(&class_from)
+        && (CCLASS_B..=CCLASS_Y).contains(&class_to)
+    {
+        75
+    } else {
+        100
+    }
+}
+
+/// Given two strings z_a and z_b which are pure ASCII, return the cost
+/// of transforming z_a into z_b. If z_a ends with '*' assume that it is
+/// a prefix of z_b and give only minimal penalty for extra characters
+/// on the end of z_b.
+///
+/// Returns the cost; smaller numbers mean a closer match.
+///
+/// # Errors
+///
+/// Returns [`EditDistanceError::NonAsciiInput`] if either string contains a
+/// non-ASCII character.
+pub fn edit_distance(z_a: &str, z_b: &str) -> EditDistanceResult {
+    if !z_a.is_ascii() || !z_b.is_ascii() {
+        return Err(EditDistanceError::NonAsciiInput);
+    }
+
+    let za_bytes = z_a.as_bytes();
+    let zb_bytes = z_b.as_bytes();
+
+    // Skip any common prefix. `dc` is the character right before the first
+    // difference, or 0 when the strings already differ at their first character.
+    let prefix_len = za_bytes
+        .iter()
+        .zip(zb_bytes)
+        .take_while(|(a, b)| a == b)
+        .count();
+    let dc = if prefix_len > 0 {
+        za_bytes[prefix_len - 1]
+    } else {
+        0
+    };
+    let za_remaining = &za_bytes[prefix_len..];
+    let zb_remaining = &zb_bytes[prefix_len..];
+    let n_a = za_remaining.len();
+    let n_b = zb_remaining.len();
+
+    // Special processing if either remaining string is empty. This also covers
+    // empty inputs and identical strings, for which the sums below are 0.
+    if n_a == 0 {
+        // Only characters appended after the end of z_a are left; they are
+        // inserted at the discounted final-insertion cost.
+        return Ok(run_cost(dc, zb_remaining, FINAL_INS_COST_DIV));
+    }
+    if n_b == 0 {
+        // Everything left in z_a has to be deleted, at full cost.
+        return Ok(run_cost(dc, za_remaining, 1));
+    }
+
+    // Check if a is a prefix pattern
+    if za_remaining == b"*" {
+        return Ok(0);
+    }
+
+    // m[x_b] is the running cost for the first x_b characters of the remaining
+    // z_b; cx[x_b] is the character recorded for that cell, which later cost
+    // lookups use as their "previous character".
+    let mut m = vec![0i32; n_b + 1];
+    let mut cx = vec![0u8; n_b + 1];
+    cx[0] = dc;
+
+    // First row: build each prefix of z_b purely by insertion.
+    let mut c_b_prev = dc;
+    for x_b in 1..=n_b {
+        let c_b = zb_remaining[x_b - 1];
+        let c_b_next = if x_b < n_b { zb_remaining[x_b] } else { 0 };
+        cx[x_b] = c_b;
+        m[x_b] = m[x_b - 1] + insert_or_delete_cost(c_b_prev, c_b, c_b_next);
+        c_b_prev = c_b;
+    }
+
+    let mut c_a_prev = dc;
+    for x_a in 1..=n_a {
+        let last_a = x_a == n_a;
+        let c_a = za_remaining[x_a - 1];
+        let c_a_next = if x_a < n_a { za_remaining[x_a] } else { 0 };
+
+        // A trailing '*' is the prefix marker, not a character to match.
+        if c_a == b'*' && last_a {
+            break;
+        }
+
+        // `d` carries the value diagonally up-left of the cell being computed.
+        let mut d = m[0];
+        m[0] = d + insert_or_delete_cost(c_a_prev, c_a, c_a_next);
+
+        for x_b in 1..=n_b {
+            let c_b = zb_remaining[x_b - 1];
+            let c_b_next = if x_b < n_b { zb_remaining[x_b] } else { 0 };
+
+            // Cost to insert c_b (discounted once the last character of z_a is reached)
+            let mut ins_cost = insert_or_delete_cost(cx[x_b - 1], c_b, c_b_next);
+            if last_a {
+                ins_cost /= FINAL_INS_COST_DIV;
+            }
+
+            // Cost to delete c_a
+            let del_cost = insert_or_delete_cost(cx[x_b], c_a, c_b_next);
+
+            // Cost to substitute c_a -> c_b
+            let sub_cost = substitute_cost(cx[x_b - 1], c_a, c_b);
+
+            // Find best cost
+            let mut total_cost = ins_cost + m[x_b - 1];
+            let mut ncx = c_b;
+
+            if del_cost + m[x_b] < total_cost {
+                total_cost = del_cost + m[x_b];
+                ncx = c_a;
+            }
+
+            if sub_cost + d < total_cost {
+                total_cost = sub_cost + d;
+            }
+
+            d = m[x_b];
+            m[x_b] = total_cost;
+            cx[x_b] = ncx;
+        }
+        c_a_prev = c_a;
+    }
+
+    let res = if za_remaining.last() == Some(&b'*') {
+        // Prefix pattern: take the cheapest cost over every non-empty prefix
+        // of the remaining z_b.
+        m[1..].iter().copied().min().unwrap_or(m[n_b])
+    } else {
+        m[n_b]
+    };
+
+    Ok(res)
+}
+
+/// Sum the insert/delete cost of every character of `run`, which directly
+/// follows `prev` (0 when `run` starts the word). Each cost is divided by
+/// `divisor` before it is added, so the rounding happens per character.
+fn run_cost(prev: u8, run: &[u8], divisor: i32) -> i32 {
+    let mut c_prev = prev;
+    let mut total = 0;
+    for (i, &c) in run.iter().enumerate() {
+        let c_next = run.get(i + 1).copied().unwrap_or(0);
+        total += insert_or_delete_cost(c_prev, c, c_next) / divisor;
+        c_prev = c;
+    }
+    total
+}
diff --git a/extensions/fuzzy/src/lib.rs b/extensions/fuzzy/src/lib.rs
new file mode 100644
index 000000000..fb578e2d6
--- /dev/null
+++ b/extensions/fuzzy/src/lib.rs
@@ -0,0 +1,522 @@
+// Adapted from sqlean fuzzy
+use std::cmp;
+use turso_ext::{register_extension, scalar, ResultCode, Value};
+mod common;
+mod editdist;
+
+register_extension! {
+    scalars: {levenshtein, damerau_levenshtein, edit_distance, hamming, jaronwin, osadist},
+}
+
+/// Calculates and returns the Levenshtein distance of two non NULL strings.
+///
+/// Yields an `InvalidArgs` error when fewer than two arguments are passed or
+/// when `to_text()` returns `None` for either of them.
+#[scalar(name = "fuzzy_leven")]
+fn levenshtein(args: &[Value]) -> Value {
+    let [first, second, ..] = args else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+    let (Some(arg1), Some(arg2)) = (first.to_text(), second.to_text()) else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+
+    Value::from_integer(leven(arg1, arg2))
+}
+
+/// Levenshtein distance between `s1` and `s2`, compared byte by byte (so a
+/// multi-byte UTF-8 character counts as several positions).
+fn leven(s1: &str, s2: &str) -> i64 {
+    let bytes1 = s1.as_bytes();
+    let bytes2 = s2.as_bytes();
+
+    if bytes1.is_empty() {
+        return bytes2.len() as i64;
+    }
+    if bytes2.is_empty() {
+        return bytes1.len() as i64;
+    }
+
+    // The common prefix never contributes to the distance, so skip it.
+    let prefix_len = bytes1
+        .iter()
+        .zip(bytes2)
+        .take_while(|(a, b)| a == b)
+        .count();
+    let str1 = &bytes1[prefix_len..];
+    let str2 = &bytes2[prefix_len..];
+    let str1_len = str1.len();
+
+    // Single-row dynamic programming: at the start of iteration `row`,
+    // `vector[c]` is the distance between the first `c` bytes of str1 and the
+    // first `row` bytes of str2.
+    let mut vector: Vec<usize> = (0..=str1_len).collect();
+
+    for (row, &c2) in str2.iter().enumerate() {
+        // Value diagonally up-left of the cell being computed.
+        let mut last_diag = row;
+        vector[0] = row + 1;
+
+        for (col, &c1) in str1.iter().enumerate() {
+            let cur = vector[col + 1];
+            let cost = usize::from(c1 != c2);
+
+            vector[col + 1] = (cur + 1).min(vector[col] + 1).min(last_diag + cost);
+            last_diag = cur;
+        }
+    }
+
+    vector[str1_len] as i64
+}
+
+/// Calculates and returns the Damerau-Levenshtein distance of two non NULL
+/// strings.
+///
+/// Yields an `InvalidArgs` error when fewer than two arguments are passed or
+/// when `to_text()` returns `None` for either of them.
+#[scalar(name = "fuzzy_damlev")]
+fn damerau_levenshtein(args: &[Value]) -> Value {
+    let [first, second, ..] = args else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+    let (Some(arg1), Some(arg2)) = (first.to_text(), second.to_text()) else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+
+    Value::from_integer(damlev(arg1, arg2))
+}
+
+/// Damerau-Levenshtein distance between `s1` and `s2`, compared byte by byte.
+// The index loops below fill two borders of `matrix` per iteration, which the
+// iterator form suggested by clippy cannot express directly.
+#[allow(clippy::needless_range_loop)]
+fn damlev(s1: &str, s2: &str) -> i64 {
+    let str1: &[u8] = s1.as_bytes();
+    let str2: &[u8] = s2.as_bytes();
+    let str1_len = str1.len();
+    let str2_len = str2.len();
+
+    if str1_len == 0 {
+        return str2_len as i64;
+    }
+
+    if str2_len == 0 {
+        return str1_len as i64;
+    }
+
+    // Skip the common prefix.
+    let mut start = 0;
+    while start < str1_len && start < str2_len && str1[start] == str2[start] {
+        start += 1;
+    }
+    let str1 = &str1[start..];
+    let str2 = &str2[start..];
+    let len1 = str1.len();
+    let len2 = str2.len();
+
+    // One slot per possible byte value, so that indexing `dict` with any `u8`
+    // is always in bounds.
+    const ALPHA_SIZE: usize = u8::MAX as usize + 1;
+    // Larger than any real distance; fills the outer border of the matrix.
+    let infi = len1 + len2;
+
+    // dict[b] is the 1-based index of the last processed row of str1 holding byte b
+    // (0 when b has not been seen yet).
+    let mut dict = vec![0usize; ALPHA_SIZE];
+
+    let rows = len1 + 2;
+    let cols = len2 + 2;
+    let mut matrix = vec![vec![0usize; cols]; rows];
+
+    matrix[0][0] = infi;
+
+    for i in 1..rows {
+        matrix[i][0] = infi;
+        matrix[i][1] = i - 1;
+    }
+    for j in 1..cols {
+        matrix[0][j] = infi;
+        matrix[1][j] = j - 1;
+    }
+
+    for (row, &c1) in str1.iter().enumerate() {
+        // 1-based index of the last column in this row where the bytes matched.
+        let mut db = 0;
+        for (col, &c2) in str2.iter().enumerate() {
+            let i = dict[c2 as usize];
+            let k = db;
+            let cost = if c1 == c2 { 0 } else { 1 };
+            if cost == 0 {
+                db = col + 1;
+            }
+
+            matrix[row + 2][col + 2] = std::cmp::min(
+                std::cmp::min(
+                    matrix[row + 1][col + 1] + cost,
+                    matrix[row + 2][col + 1] + 1,
+                ),
+                std::cmp::min(
+                    matrix[row + 1][col + 2] + 1,
+                    // Transposition reaching back to cell (i, k).
+                    matrix[i][k] + (row + 1 - i - 1) + (col + 1 - k - 1) + 1,
+                ),
+            );
+        }
+        dict[c1 as usize] = row + 1;
+    }
+
+    matrix[rows - 1][cols - 1] as i64
+}
+
+/// fuzzy_editdist(A,B)
+///
+/// Return the cost of transforming string A into string B. Both strings
+/// must be pure ASCII text. If A ends with '*' then it is assumed to be
+/// a prefix of B and extra characters on the end of B have minimal additional
+/// cost.
+///
+/// Yields an `InvalidArgs` error when fewer than two arguments are passed,
+/// when `to_text()` returns `None` for either of them, or when
+/// `editdist::edit_distance` rejects the input.
+#[scalar(name = "fuzzy_editdist")]
+pub fn edit_distance(args: &[Value]) -> Value {
+    let [first, second, ..] = args else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+    let (Some(arg1), Some(arg2)) = (first.to_text(), second.to_text()) else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+
+    match editdist::edit_distance(arg1, arg2) {
+        Ok(cost) => Value::from_integer(i64::from(cost)),
+        Err(_) => Value::error(ResultCode::InvalidArgs),
+    }
+}
+
+/// Returns the hamming distance between two strings, or -1 when their byte
+/// lengths differ.
+///
+/// Yields an `InvalidArgs` error when fewer than two arguments are passed or
+/// when `to_text()` returns `None` for either of them.
+#[scalar(name = "fuzzy_hamming")]
+fn hamming(args: &[Value]) -> Value {
+    let [first, second, ..] = args else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+    let (Some(arg1), Some(arg2)) = (first.to_text(), second.to_text()) else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+
+    Value::from_integer(hamming_dist(arg1, arg2))
+}
+
+/// Number of byte positions at which `s1` and `s2` differ, or -1 when the two
+/// strings do not have the same byte length.
+fn hamming_dist(s1: &str, s2: &str) -> i64 {
+    let str1_b = s1.as_bytes();
+    let str2_b = s2.as_bytes();
+
+    if str1_b.len() != str2_b.len() {
+        return -1_i64;
+    }
+
+    let res = str1_b
+        .iter()
+        .zip(str2_b.iter())
+        .filter(|(a, b)| a != b)
+        .count();
+
+    res as i64
+}
+
+/// Returns the score computed by `jaro_winkler` for two strings, as a float.
+///
+/// Yields an `InvalidArgs` error when fewer than two arguments are passed or
+/// when `to_text()` returns `None` for either of them.
+#[scalar(name = "fuzzy_jarowin")]
+fn jaronwin(args: &[Value]) -> Value {
+    let [first, second, ..] = args else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+    let (Some(arg1), Some(arg2)) = (first.to_text(), second.to_text()) else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+
+    Value::from_float(jaro_winkler(arg1, arg2))
+}
+
+/// Calculates and returns the Jaro-Winkler distance of two non NULL strings.
+///
+/// The result is `jaro + prefix_len * 0.1 * (1 - jaro)`, where `prefix_len` is
+/// the number of leading characters the two strings share, counted up to 3.
+fn jaro_winkler(s1: &str, s2: &str) -> f64 {
+    let dist = jaro(s1, s2);
+
+    // NOTE(review): Winkler's scheme is usually quoted with a 4-character cap;
+    // confirm that 3 is the intended limit.
+    let prefix_len = s1
+        .chars()
+        .zip(s2.chars())
+        .take(3)
+        .take_while(|(c1, c2)| c1 == c2)
+        .count();
+
+    dist + (prefix_len as f64) * 0.1 * (1.0 - dist)
+}
+
+/// Calculates and returns the Jaro distance of two non NULL strings.
+///
+/// Strings are compared character by character. Equal strings score 1.0; if
+/// exactly one string is empty, or no characters match, the score is 0.0.
+fn jaro(s1: &str, s2: &str) -> f64 {
+    if s1 == s2 {
+        return 1.0;
+    }
+
+    let s1: Vec<char> = s1.chars().collect();
+    let s2: Vec<char> = s2.chars().collect();
+
+    let len1 = s1.len();
+    let len2 = s2.len();
+
+    if len1 == 0 || len2 == 0 {
+        return 0.0;
+    }
+
+    // Two equal characters only count as a match when their positions are at
+    // most `max_dist` apart.
+    let max_dist = (cmp::max(len1, len2) / 2).saturating_sub(1);
+    let mut match_count = 0;
+
+    // Flags for the characters of each string that already belong to a match.
+    let mut hash_s1 = vec![false; len1];
+    let mut hash_s2 = vec![false; len2];
+
+    for i in 0..len1 {
+        let start = i.saturating_sub(max_dist);
+        let end = cmp::min(i + max_dist + 1, len2);
+
+        for j in start..end {
+            if s1[i] == s2[j] && !hash_s2[j] {
+                hash_s1[i] = true;
+                hash_s2[j] = true;
+                match_count += 1;
+                break;
+            }
+        }
+    }
+
+    if match_count == 0 {
+        return 0.0;
+    }
+
+    // Walk both sets of matched characters in order; every position where they
+    // disagree is counted, and two such positions make one transposition.
+    let mut t = 0;
+    let mut point = 0;
+
+    for i in 0..len1 {
+        if hash_s1[i] {
+            while point < len2 && !hash_s2[point] {
+                point += 1;
+            }
+            if point < len2 && s1[i] != s2[point] {
+                t += 1;
+            }
+            point += 1;
+        }
+    }
+
+    let t = t as f64 / 2.0;
+    let match_count = match_count as f64;
+
+    (match_count / len1 as f64 + match_count / len2 as f64 + (match_count - t) / match_count) / 3.0
+}
+
+/// Computes and returns the Optimal String Alignment distance for two non NULL
+/// strings.
+///
+/// Yields an `InvalidArgs` error when fewer than two arguments are passed or
+/// when `to_text()` returns `None` for either of them.
+#[scalar(name = "fuzzy_osadist")]
+pub fn osadist(args: &[Value]) -> Value {
+    let [first, second, ..] = args else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+    let (Some(arg1), Some(arg2)) = (first.to_text(), second.to_text()) else {
+        return Value::error(ResultCode::InvalidArgs);
+    };
+
+    Value::from_integer(optimal_string_alignment(arg1, arg2) as i64)
+}
+
+/// Distance between `s1` and `s2`, compared character by character, using
+/// insertions, deletions, substitutions and the transposition step below.
+fn optimal_string_alignment(s1: &str, s2: &str) -> usize {
+    let s1_all: Vec<char> = s1.chars().collect();
+    let s2_all: Vec<char> = s2.chars().collect();
+
+    // Drop the common prefix by slicing instead of repeatedly removing the
+    // first element, which would shift the whole vector every time.
+    let prefix_len = s1_all
+        .iter()
+        .zip(&s2_all)
+        .take_while(|(a, b)| a == b)
+        .count();
+    let s1_chars = &s1_all[prefix_len..];
+    let s2_chars = &s2_all[prefix_len..];
+    let len1 = s1_chars.len();
+    let len2 = s2_chars.len();
+
+    if len1 == 0 {
+        return len2;
+    }
+    if len2 == 0 {
+        return len1;
+    }
+
+    let mut matrix = vec![vec![0usize; len2 + 1]; len1 + 1];
+
+    // First column and first row: distance from/to the empty string.
+    for (i, row) in matrix.iter_mut().enumerate() {
+        row[0] = i;
+    }
+    for (j, item) in matrix[0].iter_mut().enumerate() {
+        *item = j;
+    }
+
+    for i in 1..=len1 {
+        for j in 1..=len2 {
+            let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
+                0
+            } else {
+                1
+            };
+
+            let deletion = matrix[i - 1][j] + 1;
+            let insertion = matrix[i][j - 1] + 1;
+            let substitution = matrix[i - 1][j - 1] + cost;
+
+            matrix[i][j] = deletion.min(insertion).min(substitution);
+
+            // NOTE(review): this compares s1_chars[i % len1] and s2_chars[j % len2],
+            // not the current characters (i - 1 / j - 1), so an adjacent swap such
+            // as "abc" -> "acb" still costs 2. The expectations in test_osadist pin
+            // this behaviour (presumably for parity with upstream sqlean); confirm
+            // before changing it.
+            if i > 1
+                && j > 1
+                && s1_chars[i % len1] == s2_chars[j - 2]
+                && s1_chars[i - 2] == s2_chars[j % len2]
+            {
+                matrix[i][j] = matrix[i][j].min(matrix[i - 2][j - 2] + cost);
+            }
+        }
+    }
+
+    matrix[len1][len2]
+}
+
+//tests adapted from sqlean fuzzy
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_damlev() {
+        let cases = vec![
+            ("abc", "abc", 0),
+            ("abc", "", 3),
+            ("", "abc", 3),
+            ("abc", "ab", 1),
+            ("abc", "abcd", 1),
+            ("abc", "acb", 1),
+            ("abc", "ca", 2),
+        ];
+
+        for (s1, s2, expected) in cases {
+            let got = damlev(s1, s2);
+            assert_eq!(got, expected, "damlev({}, {}) failed", s1, s2);
+        }
+    }
+
+    #[test]
+    fn test_hamming() {
+        let cases = vec![
+            ("abc", "abc", 0),
+            ("abc", "", -1),
+            ("", "abc", -1),
+            ("hello", "hellp", 1),
+            ("hello", "heloh", 2),
+        ];
+
+        for (s1, s2, expected) in cases {
+            let got = hamming_dist(s1, s2);
+            assert_eq!(got, expected, "hamming({}, {}) failed", s1, s2);
+        }
+    }
+
+    #[test]
+    fn test_jaro_win() {
+        let cases: Vec<(&str, &str, f64)> = vec![
+            ("abc", "abc", 1.0),
+            ("abc", "", 0.0),
+            ("", "abc", 0.0),
+            ("my string", "my tsring", 0.974),
+            ("my string", "my ntrisg", 0.896),
+        ];
+
+        for (s1, s2, expected) in cases {
+            // Expected values are given to three decimals, so compare at that precision.
+            let got = jaro_winkler(s1, s2);
+            let got_rounded = (got * 1000.0).round() / 1000.0;
+            assert!(
+                (got_rounded - expected).abs() < 1e-6,
+                "jaro_winkler({}, {}) failed: got {}, expected {}",
+                s1,
+                s2,
+                got_rounded,
+                expected
+            );
+        }
+    }
+
+    #[test]
+    fn test_leven() {
+        let cases = vec![
+            ("abc", "abc", 0),
+            ("abc", "", 3),
+            ("", "abc", 3),
+            ("abc", "ab", 1),
+            ("abc", "abcd", 1),
+            ("abc", "acb", 2),
+            ("abc", "ca", 3),
+        ];
+
+        for (s1, s2, expected) in cases {
+            let got = leven(s1, s2);
+            assert_eq!(got, expected, "leven({}, {}) failed", s1, s2);
+        }
+    }
+
+    #[test]
+    fn test_edit_distance() {
+        let test_cases = vec![
+            ("abc", "abc", 0),
+            ("abc", "", 300),
+            ("", "abc", 75),
+            ("abc", "ab", 100),
+            ("abc", "abcd", 25),
+            ("abc", "acb", 110),
+            ("abc", "ca", 225),
+            //more cases
+            ("awesome", "aewsme", 215),
+            ("kitten", "sitting", 105),
+            ("flaw", "lawn", 110),
+            ("rust", "trust", 100),
+            ("gumbo", "gambol", 65),
+        ];
+        for (s1, s2, expected) in test_cases {
+            let res = editdist::edit_distance(s1, s2).unwrap();
+            assert_eq!(res, expected, "edit_distance({}, {}) failed", s1, s2);
+        }
+    }
+
+    #[test]
+    fn test_osadist() {
+        let cases = vec![
+            ("abc", "abc", 0),
+            ("abc", "", 3),
+            ("", "abc", 3),
+            ("abc", "ab", 1),
+            ("abc", "abcd", 1),
+            ("abc", "acb", 2),
+            ("abc", "ca", 3),
+        ];
+
+        for (s1, s2, expected) in cases {
+            let got = optimal_string_alignment(s1, s2);
+            assert_eq!(got, expected, "osadist({}, {}) failed", s1, s2);
+        }
+    }
+}
diff --git
a/testing/cli_tests/extensions.py b/testing/cli_tests/extensions.py
index f53621fb9..332bd666c 100755
--- a/testing/cli_tests/extensions.py
+++ b/testing/cli_tests/extensions.py
@@ -560,6 +560,71 @@ def test_ipaddr():
     )
     limbo.quit()
 
+def validate_fuzzy_leven(a):  # expected output of fuzzy_leven('awesome', 'aewsme')
+    return a == "3"
+
+def validate_fuzzy_damlev1(a):  # expected output of fuzzy_damlev('awesome', 'aewsme')
+    return a == "2"
+
+def validate_fuzzy_damlev2(a):  # expected output of fuzzy_damlev('Something', 'Smoething')
+    return a == "1"
+
+def validate_fuzzy_editdist1(a):  # expected output of fuzzy_editdist('abc', 'ca')
+    return a == "225"
+
+def validate_fuzzy_editdist2(a):  # expected output of fuzzy_editdist('abc', 'acb')
+    return a == "110"
+
+def validate_fuzzy_jarowin(a):  # expected output of fuzzy_jarowin('awesome', 'aewsme')
+    return a == "0.907142857142857"
+
+def validate_fuzzy_osadist(a):  # expected output of fuzzy_osadist('awesome', 'aewsme')
+    return a == "3"
+
+def test_fuzzy():  # checks each fuzzy_* SQL function once the extension is loaded
+    limbo = TestTursoShell()
+    ext_path = "./target/debug/liblimbo_fuzzy"
+    limbo.run_test_fn(  # before .load the function must be unknown
+        "SELECT fuzzy_leven('awesome', 'aewsme');",
+        lambda res: "error: no such function: " in res,
+        "fuzzy levenshtein function returns null when ext not loaded",  # NOTE: the check is for the error text, not NULL
+    )
+    limbo.execute_dot(f".load {ext_path}")
+    limbo.run_test_fn(
+        "SELECT fuzzy_leven('awesome', 'aewsme');",
+        validate_fuzzy_leven,
+        "fuzzy levenshtein function works",
+    )
+    limbo.run_test_fn(
+        "SELECT fuzzy_damlev('awesome', 'aewsme');",
+        validate_fuzzy_damlev1,
+        "fuzzy damerau levenshtein1 function works",
+    )
+    limbo.run_test_fn(
+        "SELECT fuzzy_damlev('Something', 'Smoething');",
+        validate_fuzzy_damlev2,
+        "fuzzy damerau levenshtein2 function works",
+    )
+    limbo.run_test_fn(
+        "SELECT fuzzy_editdist('abc', 'ca');",
+        validate_fuzzy_editdist1,
+        "fuzzy editdist1 function works",
+    )
+    limbo.run_test_fn(
+        "SELECT fuzzy_editdist('abc', 'acb');",
+        validate_fuzzy_editdist2,
+        "fuzzy editdist2 function works",
+    )
+    limbo.run_test_fn(
+        "SELECT fuzzy_jarowin('awesome', 'aewsme');",
+        validate_fuzzy_jarowin,
+        "fuzzy jarowin function works",
+    )
+    limbo.run_test_fn(
+        "SELECT fuzzy_osadist('awesome', 'aewsme');",
+        validate_fuzzy_osadist,
+        "fuzzy osadist function works",
+    )
 
 def test_vfs():
     limbo = TestTursoShell()
@@ -822,6 +887,7 @@ def main():
         test_kv()
         test_csv()
         test_tablestats()
+        test_fuzzy()  # run the fuzzy extension checks with the other extension tests
     except Exception as e:
         console.error(f"Test FAILED: {e}")
         cleanup()