mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-29 22:14:23 +01:00
277 lines
7.2 KiB
Rust
277 lines
7.2 KiB
Rust
// Adapted from SQLite spellfix.c extension and sqlean fuzzy/editdist.c
|
|
use crate::common::*;
|
|
|
|
/// Errors reported by the edit-distance computation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EditDistanceError {
    /// At least one of the input strings contains a non-ASCII character.
    NonAsciiInput,
}

impl std::fmt::Display for EditDistanceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NonAsciiInput => f.write_str("input contains non-ASCII characters"),
        }
    }
}

// Implementing `Error` lets callers propagate this type with `?` into
// `Box<dyn Error>` and similar error containers.
impl std::error::Error for EditDistanceError {}
|
|
|
|
/// Outcome of an edit-distance computation: the `i32` cost on success, or an
/// [`EditDistanceError`] describing why the inputs were rejected.
pub type EditDistanceResult = Result<i32, EditDistanceError>;
|
|
|
|
fn character_class(c_prev: u8, c: u8) -> u8 {
|
|
if c_prev == 0 {
|
|
INIT_CLASS[(c & 0x7f) as usize]
|
|
} else {
|
|
MID_CLASS[(c & 0x7f) as usize]
|
|
}
|
|
}
|
|
|
|
/// Return the cost of inserting or deleting character c immediately
|
|
/// following character c_prev. If c_prev == 0, that means c is the first
|
|
/// character of the word.
|
|
fn insert_or_delete_cost(c_prev: u8, c: u8, c_next: u8) -> i32 {
|
|
let class_c = character_class(c_prev, c);
|
|
|
|
if class_c == CCLASS_SILENT {
|
|
return 1;
|
|
}
|
|
|
|
if c_prev == c {
|
|
return 10;
|
|
}
|
|
|
|
if class_c == CCLASS_VOWEL && (c_prev == b'r' || c_next == b'r') {
|
|
return 20; // Insert a vowel before or after 'r'
|
|
}
|
|
|
|
let class_c_prev = character_class(c_prev, c_prev);
|
|
if class_c == class_c_prev {
|
|
if class_c == CCLASS_VOWEL {
|
|
15
|
|
} else {
|
|
50
|
|
}
|
|
} else {
|
|
// Any other character insertion or deletion
|
|
100
|
|
}
|
|
}
|
|
|
|
/// Insertion costs are divided by this value when the insertion happens at the
/// end of the first string, i.e. when the first string is already exhausted or
/// its last character is being processed.
const FINAL_INS_COST_DIV: i32 = 4;
|
|
|
|
/// Return the cost of substituting c_to in place of c_from assuming
|
|
/// the previous character is c_prev. If c_prev == 0 then c_to is the first
|
|
/// character of the word.
|
|
fn substitute_cost(c_prev: u8, c_from: u8, c_to: u8) -> i32 {
|
|
if c_from == c_to {
|
|
return 0;
|
|
}
|
|
|
|
if c_from == (c_to ^ 0x20) && c_to.is_ascii_alphabetic() {
|
|
return 0;
|
|
}
|
|
|
|
let class_from = character_class(c_prev, c_from);
|
|
let class_to = character_class(c_prev, c_to);
|
|
|
|
if class_from == class_to {
|
|
40
|
|
} else if (CCLASS_B..=CCLASS_Y).contains(&class_from)
|
|
&& (CCLASS_B..=CCLASS_Y).contains(&class_to)
|
|
{
|
|
75
|
|
} else {
|
|
100
|
|
}
|
|
}
|
|
|
|
/// Given two strings z_a and z_b which are pure ASCII, return the cost
|
|
/// of transforming z_a into z_b. If z_a ends with '*' assume that it is
|
|
/// a prefix of z_b and give only minimal penalty for extra characters
|
|
/// on the end of z_b.
|
|
///
|
|
/// Returns cost where smaller numbers mean a closer match
|
|
///
|
|
/// Returns Err for Non-ASCII characters on input
|
|
pub fn edit_distance(z_a: &str, z_b: &str) -> EditDistanceResult {
|
|
if z_a.is_empty() && z_b.is_empty() {
|
|
return Ok(0);
|
|
}
|
|
|
|
let za_bytes = z_a.as_bytes();
|
|
let zb_bytes = z_b.as_bytes();
|
|
|
|
if !z_a.is_ascii() || !z_b.is_ascii() {
|
|
return Err(EditDistanceError::NonAsciiInput);
|
|
}
|
|
|
|
if z_a.is_empty() {
|
|
let mut res = 0;
|
|
let mut c_b_prev = 0u8;
|
|
let zb_bytes = z_b.as_bytes();
|
|
|
|
for (i, &c_b) in zb_bytes.iter().enumerate() {
|
|
let c_b_next = if i + 1 < zb_bytes.len() {
|
|
zb_bytes[i + 1]
|
|
} else {
|
|
0
|
|
};
|
|
res += insert_or_delete_cost(c_b_prev, c_b, c_b_next) / FINAL_INS_COST_DIV;
|
|
c_b_prev = c_b;
|
|
}
|
|
return Ok(res);
|
|
}
|
|
|
|
if z_b.is_empty() {
|
|
let mut res = 0;
|
|
let mut c_a_prev = 0u8;
|
|
let za_bytes = z_a.as_bytes();
|
|
|
|
for (i, &c_a) in za_bytes.iter().enumerate() {
|
|
let c_a_next = if i + 1 < za_bytes.len() {
|
|
za_bytes[i + 1]
|
|
} else {
|
|
0
|
|
};
|
|
res += insert_or_delete_cost(c_a_prev, c_a, c_a_next);
|
|
c_a_prev = c_a;
|
|
}
|
|
return Ok(res);
|
|
}
|
|
|
|
let mut za_start = 0;
|
|
let mut zb_start = 0;
|
|
|
|
// Skip any common prefix
|
|
while za_start < za_bytes.len()
|
|
&& zb_start < zb_bytes.len()
|
|
&& za_bytes[za_start] == zb_bytes[zb_start]
|
|
{
|
|
za_start += 1;
|
|
zb_start += 1;
|
|
}
|
|
|
|
// If both strings are exhausted after common prefix
|
|
if za_start >= za_bytes.len() && zb_start >= zb_bytes.len() {
|
|
return Ok(0);
|
|
}
|
|
|
|
let za_remaining = &za_bytes[za_start..];
|
|
let zb_remaining = &zb_bytes[zb_start..];
|
|
let n_a = za_remaining.len();
|
|
let n_b = zb_remaining.len();
|
|
|
|
// Special processing if either remaining string is empty after prefix matching
|
|
if n_a == 0 {
|
|
let mut res = 0;
|
|
let mut c_b_prev = if za_start > 0 {
|
|
za_bytes[za_start - 1]
|
|
} else {
|
|
0
|
|
};
|
|
|
|
for (i, &c_b) in zb_remaining.iter().enumerate() {
|
|
let c_b_next = if i + 1 < n_b { zb_remaining[i + 1] } else { 0 };
|
|
res += insert_or_delete_cost(c_b_prev, c_b, c_b_next) / FINAL_INS_COST_DIV;
|
|
c_b_prev = c_b;
|
|
}
|
|
return Ok(res);
|
|
}
|
|
|
|
if n_b == 0 {
|
|
let mut res = 0;
|
|
let mut c_a_prev = if za_start > 0 {
|
|
za_bytes[za_start - 1]
|
|
} else {
|
|
0
|
|
};
|
|
|
|
for (i, &c_a) in za_remaining.iter().enumerate() {
|
|
let c_a_next = if i + 1 < n_a { za_remaining[i + 1] } else { 0 };
|
|
res += insert_or_delete_cost(c_a_prev, c_a, c_a_next);
|
|
c_a_prev = c_a;
|
|
}
|
|
return Ok(res);
|
|
}
|
|
|
|
// Check if a is a prefix pattern
|
|
if za_remaining.len() == 1 && za_remaining[0] == b'*' {
|
|
return Ok(0);
|
|
}
|
|
|
|
let mut m = vec![0i32; n_b + 1];
|
|
let mut cx = vec![0u8; n_b + 1];
|
|
|
|
let dc = if za_start > 0 {
|
|
za_bytes[za_start - 1]
|
|
} else {
|
|
0
|
|
};
|
|
m[0] = 0;
|
|
cx[0] = dc;
|
|
|
|
let mut c_b_prev = dc;
|
|
for x_b in 1..=n_b {
|
|
let c_b = zb_remaining[x_b - 1];
|
|
let c_b_next = if x_b < n_b { zb_remaining[x_b] } else { 0 };
|
|
cx[x_b] = c_b;
|
|
m[x_b] = m[x_b - 1] + insert_or_delete_cost(c_b_prev, c_b, c_b_next);
|
|
c_b_prev = c_b;
|
|
}
|
|
|
|
let mut c_a_prev = dc;
|
|
for x_a in 1..=n_a {
|
|
let last_a = x_a == n_a;
|
|
let c_a = za_remaining[x_a - 1];
|
|
let c_a_next = if x_a < n_a { za_remaining[x_a] } else { 0 };
|
|
|
|
if c_a == b'*' && last_a {
|
|
break;
|
|
}
|
|
|
|
let mut d = m[0];
|
|
m[0] = d + insert_or_delete_cost(c_a_prev, c_a, c_a_next);
|
|
|
|
for x_b in 1..=n_b {
|
|
let c_b = zb_remaining[x_b - 1];
|
|
let c_b_next = if x_b < n_b { zb_remaining[x_b] } else { 0 };
|
|
|
|
// Cost to insert c_b
|
|
let mut ins_cost = insert_or_delete_cost(cx[x_b - 1], c_b, c_b_next);
|
|
if last_a {
|
|
ins_cost /= FINAL_INS_COST_DIV;
|
|
}
|
|
|
|
// Cost to delete c_a
|
|
let del_cost = insert_or_delete_cost(cx[x_b], c_a, c_b_next);
|
|
|
|
// Cost to substitute c_a -> c_b
|
|
let sub_cost = substitute_cost(cx[x_b - 1], c_a, c_b);
|
|
|
|
// Find best cost
|
|
let mut total_cost = ins_cost + m[x_b - 1];
|
|
let mut ncx = c_b;
|
|
|
|
if del_cost + m[x_b] < total_cost {
|
|
total_cost = del_cost + m[x_b];
|
|
ncx = c_a;
|
|
}
|
|
|
|
if sub_cost + d < total_cost {
|
|
total_cost = sub_cost + d;
|
|
}
|
|
|
|
d = m[x_b];
|
|
m[x_b] = total_cost;
|
|
cx[x_b] = ncx;
|
|
}
|
|
c_a_prev = c_a;
|
|
}
|
|
|
|
let res = if za_remaining.last() == Some(&b'*') {
|
|
let mut min_cost = m[1];
|
|
|
|
for &val in m.iter().skip(1).take(n_b) {
|
|
if val < min_cost {
|
|
min_cost = val;
|
|
}
|
|
}
|
|
|
|
min_cost
|
|
} else {
|
|
m[n_b]
|
|
};
|
|
|
|
Ok(res)
|
|
}
|