From 23d6080531e32241bf1c138ec9e98a119bf4daf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mika=C3=ABl=20Francoeur?= Date: Mon, 10 Nov 2025 19:20:55 -0500 Subject: [PATCH] make FromClause recursive --- simulator/model/mod.rs | 36 ++++++++++++----- sql_generation/generation/query.rs | 13 +++---- sql_generation/model/query/select.rs | 58 +++++++++++++++++++++------- 3 files changed, 76 insertions(+), 31 deletions(-) diff --git a/simulator/model/mod.rs b/simulator/model/mod.rs index 2bf95c660..116e322e7 100644 --- a/simulator/model/mod.rs +++ b/simulator/model/mod.rs @@ -4,6 +4,7 @@ use anyhow::Context; use bitflags::bitflags; use indexmap::IndexSet; use serde::{Deserialize, Serialize}; +use sql_generation::model::query::select::SelectTable; use sql_generation::model::{ query::{ alter_table::{AlterTable, AlterTableType}, pragma::Pragma, select::{CompoundOperator, FromClause, ResultColumn, SelectInner}, transaction::{Begin, Commit, Rollback}, update::Update, Create, CreateIndex, @@ -236,7 +237,7 @@ impl From for QueryCapabilities { } impl QueryDiscriminants { - pub const ALL_NO_TRANSACTION: &[QueryDiscriminants] = &[ + pub const ALL_NO_TRANSACTION: &'_ [QueryDiscriminants] = &[ QueryDiscriminants::Select, QueryDiscriminants::Create, QueryDiscriminants::Insert, @@ -354,16 +355,33 @@ impl Shadow for Insert { impl Shadow for FromClause { type Result = anyhow::Result; fn shadow(&self, tables: &mut ShadowTablesMut) -> Self::Result { - let first_table = tables - .iter() - .find(|t| t.name == self.table) - .context("Table not found")?; - - let mut join_table = JoinTable { - tables: vec![first_table.clone()], - rows: first_table.rows.clone(), + let mut join_table = match &self.table { + SelectTable::Table(table) => { + let first_table = tables + .iter() + .find(|t| t.name == *table) + .context("Table not found")?; + JoinTable { + tables: vec![first_table.clone()], + rows: first_table.rows.clone(), + } + } + SelectTable::Select(select) => { + let select_dependencies = select.dependencies(); + let result_tables = tables + .iter() + .filter(|shadow_table| select_dependencies.contains(shadow_table.name.as_str())) + .cloned() + .collect(); + let rows = select.shadow(tables)?; + JoinTable { + tables: result_tables, + rows, + } + } }; + for join in &self.joins { let joined_table = tables .iter() diff --git a/sql_generation/generation/query.rs b/sql_generation/generation/query.rs index 420659599..bd15321eb 100644 --- a/sql_generation/generation/query.rs +++ b/sql_generation/generation/query.rs @@ -4,10 +4,7 @@ use crate::generation::{ }; use crate::model::query::alter_table::{AlterTable, AlterTableType, AlterTableTypeDiscriminants}; use crate::model::query::predicate::Predicate; -use crate::model::query::select::{ - CompoundOperator, CompoundSelect, Distinctness, FromClause, OrderBy, ResultColumn, SelectBody, - SelectInner, -}; +use crate::model::query::select::{CompoundOperator, CompoundSelect, Distinctness, FromClause, OrderBy, ResultColumn, SelectBody, SelectInner, SelectTable}; use crate::model::query::update::Update; use crate::model::query::{Create, CreateIndex, Delete, Drop, DropIndex, Insert, Select}; use crate::model::table::{ @@ -84,7 +81,7 @@ impl Arbitrary for FromClause { }) }) .collect(); - FromClause { table: name, joins } + FromClause { table: SelectTable::Table(name), joins } } } @@ -98,11 +95,12 @@ impl Arbitrary for SelectInner { let order_by = rng .random_bool(env.opts().query.select.order_by_prob) .then(|| { + let dependencies = &from.table.dependencies(); let order_by_table_candidates = from .joins .iter() .map(|j| &j.table) - .chain(std::iter::once(&from.table)) + .chain(dependencies) .collect::>(); let order_by_col_count = (rng.random::() * rng.random::() * (cuml_col_count as f64)) as usize; // skew towards 0 @@ -155,11 +153,12 @@ impl ArbitrarySized for SelectInner { ) -> Self { let mut select_inner = SelectInner::arbitrary(rng, env); let select_from = &select_inner.from.as_ref().unwrap(); + let dependencies = select_from.table.dependencies(); let table_names = select_from .joins .iter() .map(|j| &j.table) - .chain(std::iter::once(&select_from.table)); + .chain(&dependencies); let flat_columns_names = table_names .flat_map(|t| { diff --git a/sql_generation/model/query/select.rs b/sql_generation/model/query/select.rs index 3ba71c06f..f2e8ec96e 100644 --- a/sql_generation/model/query/select.rs +++ b/sql_generation/model/query/select.rs @@ -82,7 +82,7 @@ impl Select { distinctness: distinct, columns: result_columns, from: Some(FromClause { - table, + table: SelectTable::Table(table), joins: Vec::new(), }), where_clause, @@ -112,7 +112,6 @@ impl Select { } let from = self.body.select.from.as_ref().unwrap(); let mut tables = IndexSet::new(); - tables.insert(from.table.clone()); tables.extend(from.dependencies()); @@ -178,19 +177,28 @@ pub struct CompoundSelect { #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct FromClause { /// table - pub table: String, + pub table: SelectTable, /// `JOIN`ed tables pub joins: Vec, } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum SelectTable { + Table(String), + Select(Select), +} + impl FromClause { fn to_sql_ast(&self) -> ast::FromClause { ast::FromClause { - select: Box::new(ast::SelectTable::Table( - ast::QualifiedName::single(ast::Name::from_string(&self.table)), - None, - None, - )), + select: Box::new(match &self.table { + SelectTable::Table(table) => ast::SelectTable::Table( + ast::QualifiedName::single(ast::Name::from_string(table)), + None, + None, + ), + SelectTable::Select(select) => ast::SelectTable::Select(select.to_sql_ast(), None), + }), joins: self .joins .iter() @@ -214,7 +222,7 @@ impl FromClause { } pub fn dependencies(&self) -> Vec { - let mut deps = vec![self.table.clone()]; + let mut deps = self.table.dependencies(); for join in &self.joins { deps.push(join.table.clone()); } @@ -222,9 +230,15 @@ impl FromClause { } pub fn into_join_table(&self, tables: &[Table]) -> JoinTable { + let self_table = if let SelectTable::Table(table) = &self.table { + table.clone() + } else { + unimplemented!("into_join_table is only implemented for Table"); + }; + let first_table = tables .iter() - .find(|t| t.name == self.table) + .find(|t| t.name == self_table) .expect("Table not found"); let mut join_table = JoinTable { @@ -368,9 +382,23 @@ impl Display for Select { } } -#[cfg(test)] -mod select_tests { - - #[test] - fn test_select_display() {} +impl SelectTable { + pub fn dependencies(&self) -> Vec { + match self { + SelectTable::Table(table) => vec![table.to_owned()], + SelectTable::Select(select) => { + if let Some(from) = &select.body.select.from { + let mut dependencies = from.table.dependencies(); + dependencies.extend( + from.joins + .iter() + .map(|joined_table| joined_table.table.clone()), + ); + dependencies + } else { + vec![] + } + } + } + } }