Change get_column_mapping to return an Option now that we support excluded.col in upsert

This commit is contained in:
PThorpe92
2025-08-29 20:36:03 -04:00
parent c659a0e4d4
commit e4a0a57227
2 changed files with 50 additions and 46 deletions

View File

@@ -417,7 +417,7 @@ pub fn translate_insert(
let mut rewritten_sets =
collect_set_clauses_for_upsert(&table, sets, &insertion)?;
if let Some(expr) = where_clause.as_mut() {
rewrite_excluded_in_expr(expr, &insertion);
rewrite_excluded_in_expr(expr, &insertion)?;
}
emit_upsert(
&mut program,
@@ -474,8 +474,13 @@ pub fn translate_insert(
// copy each index column from the table's column registers into these scratch regs
for (i, column_mapping) in column_mappings.clone().enumerate() {
// copy from the table's column register over to the index's scratch register
let Some(col_mapping) = column_mapping else {
return Err(crate::LimboError::PlanningError(
"Column not found in INSERT".to_string(),
));
};
program.emit_insn(Insn::Copy {
src_reg: column_mapping.register,
src_reg: col_mapping.register,
dst_reg: idx_start_reg + i,
extra_amount: 0,
});
@@ -533,7 +538,7 @@ pub fn translate_insert(
let mut rewritten_sets =
collect_set_clauses_for_upsert(&table, sets, &insertion)?;
if let Some(expr) = where_clause.as_mut() {
rewrite_excluded_in_expr(expr, &insertion);
rewrite_excluded_in_expr(expr, &insertion)?;
}
let conflict_rowid_reg = program.alloc_register();
@@ -744,7 +749,7 @@ impl<'a> Insertion<'a> {
}
/// Returns the column mapping for a given column name.
pub fn get_col_mapping_by_name(&self, name: &str) -> &ColMapping<'a> {
pub fn get_col_mapping_by_name(&self, name: &str) -> Option<&ColMapping<'a>> {
if let InsertionKey::RowidAlias(mapping) = &self.key {
// If the key is a rowid alias, a NULL is emitted as the column value,
// so we need to return the key mapping instead so that the non-NULL rowid is used
@@ -755,18 +760,15 @@ impl<'a> Insertion<'a> {
.as_ref()
.is_some_and(|n| n.eq_ignore_ascii_case(name))
{
return mapping;
return Some(mapping);
}
}
self.col_mappings
.iter()
.find(|col| {
col.column
.name
.as_ref()
.is_some_and(|n| n.eq_ignore_ascii_case(name))
})
.unwrap_or_else(|| panic!("column {name} not found in insertion"))
self.col_mappings.iter().find(|col| {
col.column
.name
.as_ref()
.is_some_and(|n| n.eq_ignore_ascii_case(name))
})
}
}

View File

@@ -92,34 +92,35 @@ fn effective_collation_for_index_col(idx_col: &IndexColumn, table: &Table) -> St
/// https://sqlite.org/lang_upsert.html
///
/// Rewrite EXCLUDED.x to Expr::Register(<reg of x from insertion>)
pub fn rewrite_excluded_in_expr(expr: &mut Expr, insertion: &Insertion) {
pub fn rewrite_excluded_in_expr(expr: &mut Expr, insertion: &Insertion) -> crate::Result<()> {
match expr {
// EXCLUDED.x accept Qualified with left=excluded
Expr::Qualified(ns, col) if ns.as_str().eq_ignore_ascii_case("excluded") => {
let cname = match col {
ast::Name::Ident(s) => s.as_str(),
_ => return,
_ => return Ok(()),
};
let src = insertion.get_col_mapping_by_name(cname).register;
*expr = Expr::Register(src);
let Some(src) = insertion.get_col_mapping_by_name(cname) else {
bail_parse_error!("no such column in EXCLUDED: {}", cname);
};
*expr = Expr::Register(src.register)
}
Expr::Collate(inner, _) => rewrite_excluded_in_expr(inner, insertion),
Expr::Collate(inner, _) => rewrite_excluded_in_expr(inner, insertion)?,
Expr::Parenthesized(v) => {
for e in v {
rewrite_excluded_in_expr(e, insertion)
rewrite_excluded_in_expr(e, insertion)?
}
}
Expr::Between {
lhs, start, end, ..
} => {
rewrite_excluded_in_expr(lhs, insertion);
rewrite_excluded_in_expr(start, insertion);
rewrite_excluded_in_expr(end, insertion);
rewrite_excluded_in_expr(lhs, insertion)?;
rewrite_excluded_in_expr(start, insertion)?;
rewrite_excluded_in_expr(end, insertion)?;
}
Expr::Binary(l, _, r) => {
rewrite_excluded_in_expr(l, insertion);
rewrite_excluded_in_expr(r, insertion);
rewrite_excluded_in_expr(l, insertion)?;
rewrite_excluded_in_expr(r, insertion)?;
}
Expr::Case {
base,
@@ -127,17 +128,17 @@ pub fn rewrite_excluded_in_expr(expr: &mut Expr, insertion: &Insertion) {
else_expr,
} => {
if let Some(b) = base {
rewrite_excluded_in_expr(b, insertion)
}
rewrite_excluded_in_expr(b, insertion)?
};
for (w, t) in when_then_pairs.iter_mut() {
rewrite_excluded_in_expr(w, insertion);
rewrite_excluded_in_expr(t, insertion);
rewrite_excluded_in_expr(w, insertion)?;
rewrite_excluded_in_expr(t, insertion)?;
}
if let Some(e) = else_expr {
rewrite_excluded_in_expr(e, insertion)
rewrite_excluded_in_expr(e, insertion)?
}
}
Expr::Cast { expr: inner, .. } => rewrite_excluded_in_expr(inner, insertion),
Expr::Cast { expr: inner, .. } => rewrite_excluded_in_expr(inner, insertion)?,
Expr::FunctionCall {
args,
order_by,
@@ -145,37 +146,38 @@ pub fn rewrite_excluded_in_expr(expr: &mut Expr, insertion: &Insertion) {
..
} => {
for a in args {
rewrite_excluded_in_expr(a, insertion)
rewrite_excluded_in_expr(a, insertion)?
}
for sc in order_by {
rewrite_excluded_in_expr(&mut sc.expr, insertion)
rewrite_excluded_in_expr(&mut sc.expr, insertion)?
}
if let Some(ex) = &mut filter_over.filter_clause {
rewrite_excluded_in_expr(ex, insertion)
rewrite_excluded_in_expr(ex, insertion)?
}
}
Expr::InList { lhs, rhs, .. } => {
rewrite_excluded_in_expr(lhs, insertion);
rewrite_excluded_in_expr(lhs, insertion)?;
for e in rhs {
rewrite_excluded_in_expr(e, insertion)
rewrite_excluded_in_expr(e, insertion)?
}
}
Expr::InSelect { lhs, .. } => rewrite_excluded_in_expr(lhs, insertion),
Expr::InTable { lhs, .. } => rewrite_excluded_in_expr(lhs, insertion),
Expr::IsNull(inner) => rewrite_excluded_in_expr(inner, insertion),
Expr::InSelect { lhs, .. } => rewrite_excluded_in_expr(lhs, insertion)?,
Expr::InTable { lhs, .. } => rewrite_excluded_in_expr(lhs, insertion)?,
Expr::IsNull(inner) => rewrite_excluded_in_expr(inner, insertion)?,
Expr::Like {
lhs, rhs, escape, ..
} => {
rewrite_excluded_in_expr(lhs, insertion);
rewrite_excluded_in_expr(rhs, insertion);
rewrite_excluded_in_expr(lhs, insertion)?;
rewrite_excluded_in_expr(rhs, insertion)?;
if let Some(e) = escape {
rewrite_excluded_in_expr(e, insertion)
rewrite_excluded_in_expr(e, insertion)?
}
}
Expr::NotNull(inner) => rewrite_excluded_in_expr(inner, insertion),
Expr::Unary(_, inner) => rewrite_excluded_in_expr(inner, insertion),
Expr::NotNull(inner) => rewrite_excluded_in_expr(inner, insertion)?,
Expr::Unary(_, inner) => rewrite_excluded_in_expr(inner, insertion)?,
_ => {}
}
Ok(())
}
/// Match ON CONFLICT target to the PRIMARY KEY, if any.
@@ -565,7 +567,7 @@ pub fn collect_set_clauses_for_upsert(
);
}
for (cn, mut e) in set.col_names.iter().zip(values.into_iter()) {
rewrite_excluded_in_expr(&mut e, insertion);
rewrite_excluded_in_expr(&mut e, insertion)?;
let Some(idx) = lookup.get(&normalize_ident(cn.as_str())) else {
bail_parse_error!("no such column: {}", cn);
};