diff --git a/simulator/model/mod.rs b/simulator/model/mod.rs index 1ffa5a161..2bf95c660 100644 --- a/simulator/model/mod.rs +++ b/simulator/model/mod.rs @@ -3,16 +3,15 @@ use std::fmt::Display; use anyhow::Context; use bitflags::bitflags; use indexmap::IndexSet; -use itertools::Itertools; use serde::{Deserialize, Serialize}; use sql_generation::model::{ query::{ - Create, CreateIndex, Delete, Drop, DropIndex, Insert, Select, - alter_table::{AlterTable, AlterTableType}, - pragma::Pragma, - select::{CompoundOperator, FromClause, ResultColumn, SelectInner}, - transaction::{Begin, Commit, Rollback}, - update::Update, + alter_table::{AlterTable, AlterTableType}, pragma::Pragma, select::{CompoundOperator, FromClause, ResultColumn, SelectInner}, transaction::{Begin, Commit, Rollback}, update::Update, Create, CreateIndex, + Delete, + Drop, + DropIndex, + Insert, + Select, }, table::{Index, JoinTable, JoinType, SimValue, Table, TableContext}, }; @@ -362,7 +361,7 @@ impl Shadow for FromClause { let mut join_table = JoinTable { tables: vec![first_table.clone()], - rows: Vec::new(), + rows: first_table.rows.clone(), }; for join in &self.joins { @@ -375,29 +374,18 @@ impl Shadow for FromClause { match join.join_type { JoinType::Inner => { - // Implement inner join logic - let join_rows = joined_table - .rows - .iter() - .filter(|row| join.on.test(row, joined_table)) - .cloned() - .collect::>(); - // take a cartesian product of the rows - let all_row_pairs = join_table - .rows - .clone() - .into_iter() - .cartesian_product(join_rows.iter()); - - for (row1, row2) in all_row_pairs { - let row = row1.iter().chain(row2.iter()).cloned().collect::>(); - - let is_in = join.on.test(&row, &join_table); - - if is_in { - join_table.rows.push(row); + let prev_rows = std::mem::take(&mut join_table.rows); + let mut new_rows = Vec::new(); + for row1 in prev_rows.into_iter() { + for row2 in joined_table.rows.iter() { + let combined_row = + row1.iter().chain(row2.iter()).cloned().collect::>(); + if join.on.test(&combined_row, &join_table) { + new_rows.push(combined_row); + } } } + join_table.rows = new_rows; } _ => todo!(), }