Fix ownership semantics in extention value conversions

This commit is contained in:
PThorpe92
2025-02-17 08:15:57 -05:00
parent 38e54ca85e
commit 4d2044b010
5 changed files with 41 additions and 97 deletions

View File

@@ -11,7 +11,7 @@ type ExternAggFunc = (InitAggFunction, StepFunction, FinalizeFunction);
#[derive(Clone)]
pub struct VTabImpl {
pub module_type: VTabKind,
pub module_kind: VTabKind,
pub implementation: Rc<VTabModuleImpl>,
}
@@ -104,7 +104,7 @@ impl Database {
) -> ResultCode {
let module = Rc::new(module);
let vmodule = VTabImpl {
module_type: kind,
module_kind: kind,
implementation: module,
};
self.syms

View File

@@ -542,7 +542,7 @@ impl VirtualTable {
module_name
)))?;
if let VTabKind::VirtualTable = kind {
if module.module_type != VTabKind::VirtualTable {
if module.module_kind != VTabKind::VirtualTable {
return Err(LimboError::ExtensionError(format!(
"Virtual table module {} is not a virtual table",
module_name
@@ -612,10 +612,7 @@ impl VirtualTable {
pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result<OwnedValue> {
let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) };
let res = OwnedValue::from_ffi(&val)?;
unsafe {
val.free();
}
let res = OwnedValue::from_ffi(val)?;
Ok(res)
}

View File

@@ -308,13 +308,7 @@ fn emit_program_for_delete(
&plan.table_references,
&plan.where_clause,
)?;
if let Some(table) = plan.table_references.first() {
if table.virtual_table().is_some() {
emit_delete_vtable_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?;
} else {
emit_delete_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?;
}
}
emit_delete_insns(program, &mut t_ctx, &plan.table_references, &plan.limit)?;
// Clean up and close the main execution loop
close_loop(program, &mut t_ctx, &plan.table_references)?;
@@ -328,77 +322,6 @@ fn emit_program_for_delete(
Ok(())
}
fn emit_delete_vtable_insns(
program: &mut ProgramBuilder,
t_ctx: &mut TranslateCtx,
table_references: &[TableReference],
limit: &Option<isize>,
) -> Result<()> {
let table_reference = table_references.first().unwrap();
let cursor_id = match &table_reference.op {
Operation::Scan { .. } => program.resolve_cursor_id(&table_reference.identifier),
Operation::Search(search) => match search {
Search::RowidEq { .. } | Search::RowidSearch { .. } => {
program.resolve_cursor_id(&table_reference.identifier)
}
Search::IndexSearch { index, .. } => program.resolve_cursor_id(&index.name),
},
_ => return Ok(()),
};
let rowid_reg = program.alloc_register();
program.emit_insn(Insn::RowId {
cursor_id,
dest: rowid_reg,
});
// if we have a limit, decrement and check zero
if let Some(limit) = limit {
let limit_reg = program.alloc_register();
program.emit_insn(Insn::Integer {
value: *limit as i64,
dest: limit_reg,
});
program.mark_last_insn_constant();
program.emit_insn(Insn::DecrJumpZero {
reg: limit_reg,
target_pc: t_ctx.label_main_loop_end.unwrap(),
});
}
// we want old_rowid= rowid_reg, new_rowid= NULL, so we pass 2 arguments to VUpdate
// we need a second register for the new rowid = NULL
let new_rowid_reg = program.alloc_register();
program.emit_insn(Insn::Null {
dest: new_rowid_reg,
dest_end: None,
});
// we'll do VUpdate with arg_count=2:
// argv[0] => old_rowid = rowid_reg
// argv[1] => new_rowid = new_rowid_reg (NULL)
let Some(virtual_table) = table_reference.virtual_table() else {
return Err(crate::LimboError::ParseError(
"Table is not a virtual table".to_string(),
));
};
let conflict_action = 0u16;
let start_reg = rowid_reg;
program.emit_insn(Insn::VUpdate {
cursor_id,
arg_count: 2,
start_reg,
vtab_ptr: virtual_table.implementation.as_ref().ctx as usize,
conflict_action,
});
Ok(())
}
fn emit_delete_insns(
program: &mut ProgramBuilder,
t_ctx: &mut TranslateCtx,
@@ -423,8 +346,27 @@ fn emit_delete_insns(
cursor_id,
dest: key_reg,
});
program.emit_insn(Insn::DeleteAsync { cursor_id });
program.emit_insn(Insn::DeleteAwait { cursor_id });
if let Some(vtab) = table_reference.virtual_table() {
let conflict_action = 0u16;
let start_reg = key_reg;
let new_rowid_reg = program.alloc_register();
program.emit_insn(Insn::Null {
dest: new_rowid_reg,
dest_end: None,
});
program.emit_insn(Insn::VUpdate {
cursor_id,
arg_count: 2,
start_reg,
vtab_ptr: vtab.implementation.as_ref().ctx as usize,
conflict_action,
});
} else {
program.emit_insn(Insn::DeleteAsync { cursor_id });
program.emit_insn(Insn::DeleteAwait { cursor_id });
}
if let Some(limit) = limit {
let limit_reg = program.alloc_register();
program.emit_insn(Insn::Integer {

View File

@@ -223,8 +223,8 @@ impl OwnedValue {
}
}
pub fn from_ffi(v: &ExtValue) -> Result<Self> {
match v.value_type() {
pub fn from_ffi(v: ExtValue) -> Result<Self> {
let res = match v.value_type() {
ExtValueType::Null => Ok(OwnedValue::Null),
ExtValueType::Integer => {
let Some(int) = v.to_integer() else {
@@ -259,7 +259,11 @@ impl OwnedValue {
(code, None) => Err(LimboError::ExtensionError(code.to_string())),
}
}
};
unsafe {
v.free();
}
res
}
}
@@ -281,8 +285,7 @@ impl AggContext {
if let Self::External(ext_state) = self {
if ext_state.finalized_value.is_none() {
let final_value = unsafe { (ext_state.finalize_fn)(ext_state.state) };
ext_state.cache_final_value(OwnedValue::from_ffi(&final_value)?);
unsafe { final_value.free() };
ext_state.cache_final_value(OwnedValue::from_ffi(final_value)?);
}
}
Ok(())

View File

@@ -169,13 +169,11 @@ macro_rules! call_external_function {
) => {{
if $arg_count == 0 {
let result_c_value: ExtValue = unsafe { ($func_ptr)(0, std::ptr::null()) };
match OwnedValue::from_ffi(&result_c_value) {
match OwnedValue::from_ffi(result_c_value) {
Ok(result_ov) => {
$state.registers[$dest_register] = result_ov;
unsafe { result_c_value.free() };
}
Err(e) => {
unsafe { result_c_value.free() };
return Err(e);
}
}
@@ -188,13 +186,14 @@ macro_rules! call_external_function {
}
let argv_ptr = ext_values.as_ptr();
let result_c_value: ExtValue = unsafe { ($func_ptr)($arg_count as i32, argv_ptr) };
match OwnedValue::from_ffi(&result_c_value) {
for arg in ext_values {
unsafe { arg.free() };
}
match OwnedValue::from_ffi(result_c_value) {
Ok(result_ov) => {
$state.registers[$dest_register] = result_ov;
unsafe { result_c_value.free() };
}
Err(e) => {
unsafe { result_c_value.free() };
return Err(e);
}
}
@@ -1858,6 +1857,9 @@ impl Program {
}
let argv_ptr = ext_values.as_ptr();
unsafe { step_fn(state_ptr, argc as i32, argv_ptr) };
for ext_value in ext_values {
unsafe { ext_value.free() };
}
}
}
};