diff --git a/core/lib.rs b/core/lib.rs index 29b2bb27e..0781c52ee 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -16,8 +16,7 @@ use fallible_iterator::FallibleIterator; use pager::Pager; use schema::Schema; use sqlite3_parser::{ast::Cmd, lexer::sql::Parser}; -use std::sync::Arc; -use vdbe::Program; +use std::{borrow::BorrowMut, sync::Arc}; pub use types::Value; @@ -29,12 +28,28 @@ pub struct Database { impl Database { pub fn open(io: Arc, path: &str) -> Result { let pager = Arc::new(Pager::open(io.clone(), path)?); - let schema = Arc::new(Schema::new()); + let bootstrap_schema = Arc::new(Schema::new()); let conn = Connection { pager: pager.clone(), - schema: schema.clone(), + schema: bootstrap_schema.clone(), }; - conn.query("SELECT * FROM sqlite_schema")?; + let mut schema = Schema::new(); + let rows = conn.query("SELECT * FROM sqlite_schema")?; + if let Some(mut rows) = rows { + while let Some(row) = rows.next()? { + let ty = row.get::(0)?; + if ty != "table" { + continue; + } + let name: String = row.get::(1)?; + let root_page: i64 = row.get::(3)?; + let sql: String = row.get::(4)?; + let table = schema::Table::from_sql(&sql, root_page as usize)?; + assert_eq!(table.name, name); + schema.add_table(table.name.to_owned(), table); + } + } + let schema = Arc::new(schema); Ok(Database { pager, schema }) } @@ -129,8 +144,7 @@ impl Statement { Ok(Rows::new(state, self.program.clone(), self.pager.clone())) } - pub fn reset(&self) { - } + pub fn reset(&self) {} } pub struct Rows { diff --git a/core/schema.rs b/core/schema.rs index 916dab934..74f975562 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -1,4 +1,10 @@ +use anyhow::Result; use core::fmt; +use fallible_iterator::FallibleIterator; +use sqlite3_parser::{ + ast::{Cmd, CreateTableBody, Stmt}, + lexer::sql::Parser, +}; use std::collections::HashMap; pub struct Schema { @@ -28,6 +34,53 @@ pub struct Table { } impl Table { + pub fn from_sql(sql: &str, root_page: usize) -> Result { + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser.next()?; + match cmd { + Some(cmd) => match cmd { + Cmd::Stmt(stmt) => match stmt { + Stmt::CreateTable { tbl_name, body, .. } => { + let mut cols = vec![]; + match body { + CreateTableBody::ColumnsAndConstraints { columns, .. } => { + for column in columns { + let name = column.col_name.0.to_string(); + let ty = match column.col_type { + Some(data_type) => match data_type.name.as_str() { + "INT" => Type::Integer, + "REAL" => Type::Real, + "TEXT" => Type::Text, + "BLOB" => Type::Blob, + _ => unreachable!("Unknown type: {:?}", data_type.name), + }, + None => Type::Null, + }; + cols.push(Column { name, ty }); + } + } + CreateTableBody::AsSelect(_) => todo!(), + }; + Ok(Table { + root_page, + name: tbl_name.name.to_string(), + columns: cols, + }) + } + _ => { + anyhow::bail!("Expected CREATE TABLE statement"); + } + }, + _ => { + anyhow::bail!("Expected CREATE TABLE statement"); + } + }, + None => { + anyhow::bail!("Expected CREATE TABLE statement"); + } + } + } + pub fn to_sql(&self) -> String { let mut sql = format!("CREATE TABLE {} (\n", self.name); for (i, column) in self.columns.iter().enumerate() { diff --git a/core/types.rs b/core/types.rs index 00c60064c..9f9a46516 100644 --- a/core/types.rs +++ b/core/types.rs @@ -24,6 +24,15 @@ impl FromValue for i64 { } } +impl FromValue for String { + fn from_value(value: &Value) -> Result { + match value { + Value::Text(s) => Ok(s.clone()), + _ => anyhow::bail!("Expected text value"), + } + } +} + #[derive(Debug)] pub struct Record { pub values: Vec,