make FromClause recursive

This commit is contained in:
Mikaël Francoeur
2025-11-10 19:20:55 -05:00
parent 156693ce95
commit 23d6080531
3 changed files with 76 additions and 31 deletions

View File

@@ -4,6 +4,7 @@ use anyhow::Context;
use bitflags::bitflags;
use indexmap::IndexSet;
use serde::{Deserialize, Serialize};
use sql_generation::model::query::select::SelectTable;
use sql_generation::model::{
query::{
alter_table::{AlterTable, AlterTableType}, pragma::Pragma, select::{CompoundOperator, FromClause, ResultColumn, SelectInner}, transaction::{Begin, Commit, Rollback}, update::Update, Create, CreateIndex,
@@ -236,7 +237,7 @@ impl From<QueryDiscriminants> for QueryCapabilities {
}
impl QueryDiscriminants {
pub const ALL_NO_TRANSACTION: &[QueryDiscriminants] = &[
pub const ALL_NO_TRANSACTION: &'_ [QueryDiscriminants] = &[
QueryDiscriminants::Select,
QueryDiscriminants::Create,
QueryDiscriminants::Insert,
@@ -354,16 +355,33 @@ impl Shadow for Insert {
impl Shadow for FromClause {
type Result = anyhow::Result<JoinTable>;
fn shadow(&self, tables: &mut ShadowTablesMut) -> Self::Result {
let first_table = tables
.iter()
.find(|t| t.name == self.table)
.context("Table not found")?;
let mut join_table = JoinTable {
tables: vec![first_table.clone()],
rows: first_table.rows.clone(),
let mut join_table = match &self.table {
SelectTable::Table(table) => {
let first_table = tables
.iter()
.find(|t| t.name == *table)
.context("Table not found")?;
JoinTable {
tables: vec![first_table.clone()],
rows: first_table.rows.clone(),
}
}
SelectTable::Select(select) => {
let select_dependencies = select.dependencies();
let result_tables = tables
.iter()
.filter(|shadow_table| select_dependencies.contains(shadow_table.name.as_str()))
.cloned()
.collect();
let rows = select.shadow(tables)?;
JoinTable {
tables: result_tables,
rows,
}
}
};
for join in &self.joins {
let joined_table = tables
.iter()

View File

@@ -4,10 +4,7 @@ use crate::generation::{
};
use crate::model::query::alter_table::{AlterTable, AlterTableType, AlterTableTypeDiscriminants};
use crate::model::query::predicate::Predicate;
use crate::model::query::select::{
CompoundOperator, CompoundSelect, Distinctness, FromClause, OrderBy, ResultColumn, SelectBody,
SelectInner,
};
use crate::model::query::select::{CompoundOperator, CompoundSelect, Distinctness, FromClause, OrderBy, ResultColumn, SelectBody, SelectInner, SelectTable};
use crate::model::query::update::Update;
use crate::model::query::{Create, CreateIndex, Delete, Drop, DropIndex, Insert, Select};
use crate::model::table::{
@@ -84,7 +81,7 @@ impl Arbitrary for FromClause {
})
})
.collect();
FromClause { table: name, joins }
FromClause { table: SelectTable::Table(name), joins }
}
}
@@ -98,11 +95,12 @@ impl Arbitrary for SelectInner {
let order_by = rng
.random_bool(env.opts().query.select.order_by_prob)
.then(|| {
let dependencies = &from.table.dependencies();
let order_by_table_candidates = from
.joins
.iter()
.map(|j| &j.table)
.chain(std::iter::once(&from.table))
.chain(dependencies)
.collect::<Vec<_>>();
let order_by_col_count =
(rng.random::<f64>() * rng.random::<f64>() * (cuml_col_count as f64)) as usize; // skew towards 0
@@ -155,11 +153,12 @@ impl ArbitrarySized for SelectInner {
) -> Self {
let mut select_inner = SelectInner::arbitrary(rng, env);
let select_from = &select_inner.from.as_ref().unwrap();
let dependencies = select_from.table.dependencies();
let table_names = select_from
.joins
.iter()
.map(|j| &j.table)
.chain(std::iter::once(&select_from.table));
.chain(&dependencies);
let flat_columns_names = table_names
.flat_map(|t| {

View File

@@ -82,7 +82,7 @@ impl Select {
distinctness: distinct,
columns: result_columns,
from: Some(FromClause {
table,
table: SelectTable::Table(table),
joins: Vec::new(),
}),
where_clause,
@@ -112,7 +112,6 @@ impl Select {
}
let from = self.body.select.from.as_ref().unwrap();
let mut tables = IndexSet::new();
tables.insert(from.table.clone());
tables.extend(from.dependencies());
@@ -178,19 +177,28 @@ pub struct CompoundSelect {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FromClause {
/// table
pub table: String,
pub table: SelectTable,
/// `JOIN`ed tables
pub joins: Vec<JoinedTable>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum SelectTable {
Table(String),
Select(Select),
}
impl FromClause {
fn to_sql_ast(&self) -> ast::FromClause {
ast::FromClause {
select: Box::new(ast::SelectTable::Table(
ast::QualifiedName::single(ast::Name::from_string(&self.table)),
None,
None,
)),
select: Box::new(match &self.table {
SelectTable::Table(table) => ast::SelectTable::Table(
ast::QualifiedName::single(ast::Name::from_string(table)),
None,
None,
),
SelectTable::Select(select) => ast::SelectTable::Select(select.to_sql_ast(), None),
}),
joins: self
.joins
.iter()
@@ -214,7 +222,7 @@ impl FromClause {
}
pub fn dependencies(&self) -> Vec<String> {
let mut deps = vec![self.table.clone()];
let mut deps = self.table.dependencies();
for join in &self.joins {
deps.push(join.table.clone());
}
@@ -222,9 +230,15 @@ impl FromClause {
}
pub fn into_join_table(&self, tables: &[Table]) -> JoinTable {
let self_table = if let SelectTable::Table(table) = &self.table {
table.clone()
} else {
unimplemented!("into_join_table is only implemented for Table");
};
let first_table = tables
.iter()
.find(|t| t.name == self.table)
.find(|t| t.name == self_table)
.expect("Table not found");
let mut join_table = JoinTable {
@@ -368,9 +382,23 @@ impl Display for Select {
}
}
#[cfg(test)]
mod select_tests {
#[test]
fn test_select_display() {}
impl SelectTable {
pub fn dependencies(&self) -> Vec<String> {
match self {
SelectTable::Table(table) => vec![table.to_owned()],
SelectTable::Select(select) => {
if let Some(from) = &select.body.select.from {
let mut dependencies = from.table.dependencies();
dependencies.extend(
from.joins
.iter()
.map(|joined_table| joined_table.table.clone()),
);
dependencies
} else {
vec![]
}
}
}
}
}