From 217933c0c718dcd63431908d072fac282e728a10 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 12 Mar 2024 20:27:25 -0700 Subject: [PATCH] Started the work for crawling and added better code grabbing for the FileStore --- Cargo.lock | 1 + Cargo.toml | 1 + src/configuration.rs | 17 ++++-- src/memory_backends/file_store.rs | 84 ++++++++++++++++++++++----- src/memory_backends/mod.rs | 6 +- src/memory_backends/postgresml/mod.rs | 29 ++++----- 6 files changed, 103 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 020f9a1..e1db173 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1468,6 +1468,7 @@ dependencies = [ "directories", "hf-hub", "ignore", + "indexmap", "llama-cpp-2", "lsp-server", "lsp-types", diff --git a/Cargo.toml b/Cargo.toml index 9fdccc3..eb42568 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ reqwest = { version = "0.11.25", features = ["blocking", "json"] } ignore = "0.4.22" pgml = { path = "submodules/postgresml/pgml-sdks/pgml" } tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] } +indexmap = "2.2.5" [features] default = [] diff --git a/src/configuration.rs b/src/configuration.rs index 342f306..d473385 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -12,7 +12,7 @@ const DEFAULT_MAX_GENERATION_TOKENS: usize = 256; pub type Kwargs = HashMap; pub enum ValidMemoryBackend { - FileStore, + FileStore(FileStore), PostgresML(PostgresML), } @@ -60,18 +60,24 @@ impl Default for MaxTokens { #[derive(Clone, Debug, Deserialize)] pub struct PostgresML { pub database_url: Option, + pub crawl: bool, +} + +#[derive(Clone, Debug, Deserialize, Default)] +pub struct FileStore { + pub crawl: bool, } #[derive(Clone, Debug, Deserialize)] struct ValidMemoryConfiguration { - file_store: Option, + file_store: Option, postgresml: Option, } impl Default for ValidMemoryConfiguration { fn default() -> Self { Self { - file_store: Some(json!({})), + file_store: 
Some(FileStore::default()), postgresml: None, } } @@ -227,8 +233,9 @@ impl Configuration { } pub fn get_memory_backend(&self) -> Result { - if self.valid_config.memory.file_store.is_some() { - Ok(ValidMemoryBackend::FileStore) + // if self.valid_config.memory.file_store.is_some() { + if let Some(file_store) = &self.valid_config.memory.file_store { + Ok(ValidMemoryBackend::FileStore(file_store.to_owned())) } else if let Some(postgresml) = &self.valid_config.memory.postgresml { Ok(ValidMemoryBackend::PostgresML(postgresml.to_owned())) } else { diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs index ac924dc..15c4619 100644 --- a/src/memory_backends/file_store.rs +++ b/src/memory_backends/file_store.rs @@ -1,30 +1,92 @@ use anyhow::Context; +use indexmap::IndexSet; use lsp_types::TextDocumentPositionParams; use ropey::Rope; use std::collections::HashMap; use tracing::instrument; -use crate::{configuration::Configuration, utils::tokens_to_estimated_characters}; +use crate::{ + configuration::{self, Configuration}, + utils::tokens_to_estimated_characters, +}; use super::{MemoryBackend, Prompt, PromptForType}; pub struct FileStore { + crawl: bool, configuration: Configuration, file_map: HashMap, + accessed_files: IndexSet, } // TODO: Put some thought into the crawling here. Do we want to have a crawl option where it tries to crawl through all relevant // files and then when asked for context it loads them in by the most recently accessed? 
That seems kind of silly honestly, but I could see // how users who want to use models with massive context lengths would just want their entire project as context for generation tasks // I'm not sure yet, this is something I need to think through more + +// Ok here are some more ideas +// We take a crawl arg which is a bool of true or false for file_store +// If true we crawl until we get to the max_context_length and then we stop crawling +// We keep track of the last opened / changed files, and prioritize those when building the context for our llms + +// For memory backends like PostgresML, they will need to take some kind of max_context_length to crawl or something. +// In other words, there needs to be some specification for how much they should be crawling because the limiting happens in the vector_recall impl FileStore { - pub fn new(configuration: Configuration) -> Self { + pub fn new(file_store_config: configuration::FileStore, configuration: Configuration) -> Self { + // TODO: maybe crawl Self { + crawl: file_store_config.crawl, configuration, file_map: HashMap::new(), + accessed_files: IndexSet::new(), } } + pub fn new_without_crawl(configuration: Configuration) -> Self { + Self { + crawl: false, + configuration, + file_map: HashMap::new(), + accessed_files: IndexSet::new(), + } + } + + fn get_rope_for_position( + &self, + position: &TextDocumentPositionParams, + characters: usize, + ) -> anyhow::Result<(Rope, usize)> { + // Get the rope and set our initial cursor index + let current_document_uri = position.text_document.uri.to_string(); + let mut rope = self + .file_map + .get(&current_document_uri) + .context("Error file not found")? 
+ .clone(); + let mut cursor_index = rope.line_to_char(position.position.line as usize) + + position.position.character as usize; + // Add to our rope if we need to + for file in self + .accessed_files + .iter() + .filter(|f| **f != current_document_uri) + { + let needed = characters.checked_sub(rope.len_chars()).unwrap_or(0); + if needed == 0 { + break; + } + let r = self.file_map.get(file).context("Error file not found")?; + let slice_max = needed.min(r.len_chars()); + let rope_str_slice = r + .get_slice(0..slice_max) + .context("Error getting slice")? + .to_string(); + rope.insert(0, &rope_str_slice); + cursor_index += slice_max; + } + Ok((rope, cursor_index)) + } + pub fn get_characters_around_position( &self, position: &TextDocumentPositionParams, @@ -53,14 +115,7 @@ impl FileStore { prompt_for_type: PromptForType, max_context_length: usize, ) -> anyhow::Result { - let mut rope = self - .file_map - .get(position.text_document.uri.as_str()) - .context("Error file not found")? - .clone(); - - let cursor_index = rope.line_to_char(position.position.line as usize) - + position.position.character as usize; + let (mut rope, cursor_index) = self.get_rope_for_position(position, max_context_length)?; let is_chat_enabled = match prompt_for_type { PromptForType::Completion => self @@ -157,8 +212,9 @@ impl MemoryBackend for FileStore { params: lsp_types::DidOpenTextDocumentParams, ) -> anyhow::Result<()> { let rope = Rope::from_str(&params.text_document.text); - self.file_map - .insert(params.text_document.uri.to_string(), rope); + let uri = params.text_document.uri.to_string(); + self.file_map.insert(uri.clone(), rope); + self.accessed_files.shift_insert(0, uri); Ok(()) } @@ -167,9 +223,10 @@ impl MemoryBackend for FileStore { &mut self, params: lsp_types::DidChangeTextDocumentParams, ) -> anyhow::Result<()> { + let uri = params.text_document.uri.to_string(); let rope = self .file_map - .get_mut(params.text_document.uri.as_str()) + .get_mut(&uri) .context("Error trying to get 
file that does not exist")?; for change in params.content_changes { // If range is ommitted, text is the new text of the document @@ -184,6 +241,7 @@ impl MemoryBackend for FileStore { *rope = Rope::from_str(&change.text); } } + self.accessed_files.shift_insert(0, uri); Ok(()) } diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs index ee4dc27..0931318 100644 --- a/src/memory_backends/mod.rs +++ b/src/memory_backends/mod.rs @@ -46,9 +46,9 @@ impl TryFrom for Box { fn try_from(configuration: Configuration) -> Result { match configuration.get_memory_backend()? { - ValidMemoryBackend::FileStore => { - Ok(Box::new(file_store::FileStore::new(configuration))) - } + ValidMemoryBackend::FileStore(file_store_config) => Ok(Box::new( + file_store::FileStore::new(file_store_config, configuration), + )), ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new( postgresml::PostgresML::new(postgresml_config, configuration)?, )), diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs index aecd301..c7fce9e 100644 --- a/src/memory_backends/postgresml/mod.rs +++ b/src/memory_backends/postgresml/mod.rs @@ -1,5 +1,5 @@ use std::{ - sync::mpsc::{self, Sender, TryRecvError}, + sync::mpsc::{self, Sender}, time::Duration, }; @@ -24,6 +24,7 @@ pub struct PostgresML { pipeline: Pipeline, runtime: Runtime, debounce_tx: Sender, + added_pipeline: bool, } impl PostgresML { @@ -31,7 +32,7 @@ impl PostgresML { postgresml_config: configuration::PostgresML, configuration: Configuration, ) -> anyhow::Result { - let file_store = FileStore::new(configuration.clone()); + let file_store = FileStore::new_without_crawl(configuration.clone()); let database_url = if let Some(database_url) = postgresml_config.database_url { database_url } else { @@ -39,7 +40,7 @@ impl PostgresML { }; // TODO: Think on the naming of the collection // Maybe filter on metadata or I'm not sure - let collection = Collection::new("test-lsp-ai", Some(database_url))?; + 
let collection = Collection::new("test-lsp-ai-2", Some(database_url))?; // TODO: Review the pipeline let pipeline = Pipeline::new( "v1", @@ -66,15 +67,6 @@ impl PostgresML { .worker_threads(2) .enable_all() .build()?; - // Add the collection to the pipeline - let mut task_collection = collection.clone(); - let mut task_pipeline = pipeline.clone(); - runtime.spawn(async move { - task_collection - .add_pipeline(&mut task_pipeline) - .await - .expect("PGML - Error adding pipeline to collection"); - }); // Setup up a debouncer for changed text documents let mut task_collection = collection.clone(); let (debounce_tx, debounce_rx) = mpsc::channel::(); @@ -124,6 +116,7 @@ impl PostgresML { pipeline, runtime, debounce_tx, + added_pipeline: false, }) } } @@ -140,7 +133,7 @@ impl MemoryBackend for PostgresML { position: &TextDocumentPositionParams, prompt_for_type: PromptForType, ) -> anyhow::Result { - // This is blocking, but this is ok as we only query for it from the worker when we are actually doing a transform + // This is blocking, but that is ok as we only query for it from the worker when we are actually doing a transform let query = self .file_store .get_characters_around_position(position, 512)?; @@ -189,8 +182,16 @@ impl MemoryBackend for PostgresML { ) -> anyhow::Result<()> { let text = params.text_document.text.clone(); let path = params.text_document.uri.path().to_owned(); + let task_added_pipeline = self.added_pipeline; let mut task_collection = self.collection.clone(); + let mut task_pipeline = self.pipeline.clone(); self.runtime.spawn(async move { + if !task_added_pipeline { + task_collection + .add_pipeline(&mut task_pipeline) + .await + .expect("PGML - Error adding pipeline to collection"); + } task_collection .upsert_documents( vec![json!({ @@ -201,7 +202,7 @@ impl MemoryBackend for PostgresML { None, ) .await - .expect("PGML - Error adding pipeline to collection"); + .expect("PGML - Error upserting documents"); }); 
self.file_store.opened_text_document(params) }