From a096f2d73819c70c70e6fda2410f98c19cdcc273 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Sat, 23 Mar 2024 12:22:14 -0700 Subject: [PATCH 1/3] Checkpoint --- Cargo.lock | 93 ++++----- Cargo.toml | 3 +- src/configuration.rs | 10 +- src/main.rs | 26 +-- src/memory_backends/file_store.rs | 29 +-- src/memory_backends/mod.rs | 34 ++- src/memory_backends/postgresml/mod.rs | 131 ++++++------ src/transformer_backends/llama_cpp/mod.rs | 10 +- src/transformer_backends/mod.rs | 8 +- src/transformer_backends/openai/mod.rs | 239 ++++++++++++++++++++-- src/utils.rs | 10 +- src/worker.rs | 77 ++++--- 12 files changed, 459 insertions(+), 211 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e1db173..0cfba1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -110,9 +110,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.80" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" +checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" [[package]] name = "assert_cmd" @@ -131,9 +131,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85" dependencies = [ "proc-macro2", "quote", @@ -149,16 +149,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "atomic-write-file" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8204db279bf648d64fe845bd8840f78b39c8132ed4d6a4194c3b10d4b4cfb0b" -dependencies = [ - "nix", - "rand", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -532,6 +522,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctrlc" +version = "3.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +dependencies = [ + "nix", + "windows-sys 0.52.0", +] + [[package]] name = "darling" version = "0.14.4" @@ -1410,6 +1410,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "llama-cpp-2" version = "0.1.34" +source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064" dependencies = [ "llama-cpp-sys-2", "thiserror", @@ -1419,6 +1420,7 @@ dependencies = [ [[package]] name = "llama-cpp-sys-2" version = "0.1.34" +source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064" dependencies = [ "bindgen", "cc", @@ -1465,6 +1467,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assert_cmd", + "async-trait", "directories", "hf-hub", "ignore", @@ -1956,6 +1959,7 @@ dependencies = [ "chrono", "clap", "colored", + "ctrlc", "futures", "indicatif", "inquire", @@ -2075,9 +2079,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] @@ -2224,9 +2228,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" 
[[package]] name = "reqwest" -version = "0.11.25" +version = "0.11.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946" +checksum = "78bf93c4af7a8bb7d879d51cebe797356ff10ae8516ace542b5182d9dcac10b2" dependencies = [ "base64 0.21.7", "bytes", @@ -2769,9 +2773,9 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" +checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa" dependencies = [ "sqlx-core", "sqlx-macros", @@ -2782,9 +2786,9 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" +checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6" dependencies = [ "ahash", "atoi", @@ -2792,7 +2796,6 @@ dependencies = [ "bytes", "crc", "crossbeam-queue", - "dotenvy", "either", "event-listener", "futures-channel", @@ -2827,9 +2830,9 @@ dependencies = [ [[package]] name = "sqlx-macros" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" +checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127" dependencies = [ "proc-macro2", "quote", @@ -2840,11 +2843,10 @@ dependencies = [ [[package]] name = "sqlx-macros-core" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841" +checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8" dependencies = [ - "atomic-write-file", "dotenvy", "either", "heck", @@ -2867,9 +2869,9 @@ dependencies = [ [[package]] name = "sqlx-mysql" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" +checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418" dependencies = [ "atoi", "base64 0.21.7", @@ -2911,9 +2913,9 @@ dependencies = [ [[package]] name = "sqlx-postgres" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" +checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e" dependencies = [ "atoi", "base64 0.21.7", @@ -2938,7 +2940,6 @@ dependencies = [ "rand", "serde", "serde_json", - "sha1", "sha2", "smallvec", "sqlx-core", @@ -2952,9 +2953,9 @@ dependencies = [ [[package]] name = "sqlx-sqlite" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" +checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa" dependencies = [ "atoi", "flume", @@ -3051,20 +3052,20 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "system-configuration" -version = "0.6.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42" 
+checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ - "bitflags 2.4.2", + "bitflags 1.3.2", "core-foundation", "system-configuration-sys", ] [[package]] name = "system-configuration-sys" -version = "0.6.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" dependencies = [ "core-foundation-sys", "libc", @@ -3090,18 +3091,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "thiserror" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", @@ -3631,9 +3632,9 @@ dependencies = [ [[package]] name = "whoami" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fec781d48b41f8163426ed18e8fc2864c12937df9ce54c88ede7bd47270893e" +checksum = "a44ab49fad634e88f55bf8f9bb3abd2f27d7204172a112c7c9987e01c1c94ea9" dependencies = [ "redox_syscall", "wasite", diff --git a/Cargo.toml b/Cargo.toml index eb42568..13e8280 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ parking_lot = "0.12.1" once_cell = "1.19.0" directories = "5.0.1" # llama-cpp-2 = "0.1.31" -llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" } +llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" } minijinja = { version = "1.0.12", features = ["loader"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tracing = "0.1.40" @@ -30,6 +30,7 @@ ignore = "0.4.22" pgml = { path = "submodules/postgresml/pgml-sdks/pgml" } tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] } indexmap = "2.2.5" +async-trait = "0.1.78" [features] default = [] diff --git a/src/configuration.rs b/src/configuration.rs index d473385..c51f0bf 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -129,10 +129,6 @@ const fn openai_top_p_default() -> f32 { 0.95 } -const fn openai_top_k_default() -> usize { - 40 -} - const fn openai_presence_penalty() -> f32 { 0. 
} @@ -155,7 +151,9 @@ pub struct OpenAI { pub auth_token_env_var_name: Option, pub auth_token: Option, // The completions endpoint - pub completions_endpoint: String, + pub completions_endpoint: Option, + // The chat endpoint + pub chat_endpoint: Option, // The model name pub model: String, // Fill in the middle support @@ -168,8 +166,6 @@ pub struct OpenAI { // Other available args #[serde(default = "openai_top_p_default")] pub top_p: f32, - #[serde(default = "openai_top_k_default")] - pub top_k: usize, #[serde(default = "openai_presence_penalty")] pub presence_penalty: f32, #[serde(default = "openai_frequency_penalty")] diff --git a/src/main.rs b/src/main.rs index 9246607..2d3b5af 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,7 @@ fn main() -> Result<()> { .init(); let (connection, io_threads) = Connection::stdio(); - let server_capabilities = serde_json::to_value(&ServerCapabilities { + let server_capabilities = serde_json::to_value(ServerCapabilities { completion_provider: Some(CompletionOptions::default()), text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind( TextDocumentSyncKind::INCREMENTAL, @@ -77,7 +77,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { let transformer_backend: Box = args.clone().try_into()?; // Set the memory_backend - let memory_backend: Arc>> = + let memory_backend: Box = Arc::new(Mutex::new(args.clone().try_into()?)); // Wrap the connection for sharing between threads @@ -87,6 +87,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { let last_worker_request = Arc::new(Mutex::new(None)); // Thread local variables + // TODO: Setup some kind of handler for errors here let thread_memory_backend = memory_backend.clone(); let thread_last_worker_request = last_worker_request.clone(); let thread_connection = connection.clone(); @@ -97,7 +98,8 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { thread_last_worker_request, thread_connection, ) - .run(); + .run() + .unwrap(); }); for msg in &connection.receiver { @@ -143,13 +145,13 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { Message::Notification(not) => { if notification_is::(¬) { let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?; - memory_backend.lock().opened_text_document(params)?; + // memory_backend.lock().opened_text_document(params)?; } else if notification_is::(¬) { let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?; - memory_backend.lock().changed_text_document(params)?; + // memory_backend.lock().changed_text_document(params)?; } else if notification_is::(¬) { let params: RenameFilesParams = serde_json::from_value(not.params)?; - memory_backend.lock().renamed_file(params)?; + // memory_backend.lock().renamed_file(params)?; } } _ => (), @@ -170,18 +172,18 @@ mod tests { ////////////////////////////////////// ////////////////////////////////////// - #[test] - fn completion_with_default_arguments() { + #[tokio::test] + async fn completion_with_default_arguments() { let args = json!({}); let configuration = Configuration::new(args).unwrap(); let backend: Box = configuration.clone().try_into().unwrap(); let prompt = Prompt::new("".to_string(), "def fibn".to_string()); - let response = backend.do_completion(&prompt).unwrap(); + let response = backend.do_completion(&prompt).await.unwrap(); assert!(!response.insert_text.is_empty()) } - #[test] - fn completion_with_custom_gguf_model() { + #[tokio::test] + async fn 
completion_with_custom_gguf_model() { let args = json!({ "initializationOptions": { "memory": { @@ -232,7 +234,7 @@ mod tests { let configuration = Configuration::new(args).unwrap(); let backend: Box = configuration.clone().try_into().unwrap(); let prompt = Prompt::new("".to_string(), "def fibn".to_string()); - let response = backend.do_completion(&prompt).unwrap(); + let response = backend.do_completion(&prompt).await.unwrap(); assert!(!response.insert_text.is_empty()); } } diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs index 15c4619..e8d832e 100644 --- a/src/memory_backends/file_store.rs +++ b/src/memory_backends/file_store.rs @@ -71,7 +71,7 @@ impl FileStore { .iter() .filter(|f| **f != current_document_uri) { - let needed = characters.checked_sub(rope.len_chars()).unwrap_or(0); + let needed = characters.saturating_sub(rope.len_chars()); if needed == 0 { break; } @@ -99,7 +99,7 @@ impl FileStore { .clone(); let cursor_index = rope.line_to_char(position.position.line as usize) + position.position.character as usize; - let start = cursor_index.checked_sub(characters / 2).unwrap_or(0); + let start = cursor_index.saturating_sub(characters / 2); let end = rope .len_chars() .min(cursor_index + (characters - (cursor_index - start))); @@ -137,15 +137,15 @@ impl FileStore { if is_chat_enabled || rope.len_chars() != cursor_index => { let max_length = tokens_to_estimated_characters(max_context_length); - let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0); + let start = cursor_index.saturating_sub(max_length / 2); let end = rope .len_chars() .min(cursor_index + (max_length - (cursor_index - start))); if is_chat_enabled { - rope.insert(cursor_index, "{CURSOR}"); + rope.insert(cursor_index, ""); let rope_slice = rope - .get_slice(start..end + "{CURSOR}".chars().count()) + .get_slice(start..end + "".chars().count()) .context("Error getting rope slice")?; rope_slice.to_string() } else { @@ -166,9 +166,8 @@ impl FileStore { } } _ => { - let start = cursor_index - .checked_sub(tokens_to_estimated_characters(max_context_length)) - .unwrap_or(0); + let start = + cursor_index.saturating_sub(tokens_to_estimated_characters(max_context_length)); let rope_slice = rope .get_slice(start..cursor_index) .context("Error getting rope slice")?; @@ -178,9 +177,13 @@ impl FileStore { } } +#[async_trait::async_trait] impl MemoryBackend for FileStore { #[instrument(skip(self))] - fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result { + async fn get_filter_text( + &self, + position: &TextDocumentPositionParams, + ) -> anyhow::Result { let rope = self .file_map .get(position.text_document.uri.as_str()) @@ -193,7 +196,7 @@ impl MemoryBackend for FileStore { } #[instrument(skip(self))] - fn build_prompt( + async fn build_prompt( &mut self, position: &TextDocumentPositionParams, prompt_for_type: PromptForType, @@ -207,7 +210,7 @@ impl MemoryBackend for FileStore { } #[instrument(skip(self))] - fn opened_text_document( + async fn opened_text_document( &mut self, params: lsp_types::DidOpenTextDocumentParams, ) -> anyhow::Result<()> { @@ -219,7 +222,7 @@ impl MemoryBackend for FileStore { } #[instrument(skip(self))] - fn changed_text_document( + async fn changed_text_document( &mut self, params: lsp_types::DidChangeTextDocumentParams, ) -> anyhow::Result<()> { @@ -246,7 +249,7 @@ impl MemoryBackend for FileStore { } #[instrument(skip(self))] - fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { + async fn 
renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { for file_rename in params.files { if let Some(rope) = self.file_map.remove(&file_rename.old_uri) { self.file_map.insert(file_rename.new_uri, rope); diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs index 0931318..cb8c6b6 100644 --- a/src/memory_backends/mod.rs +++ b/src/memory_backends/mod.rs @@ -26,19 +26,29 @@ pub enum PromptForType { Generate, } +#[async_trait::async_trait] pub trait MemoryBackend { - fn init(&self) -> anyhow::Result<()> { + async fn init(&self) -> anyhow::Result<()> { Ok(()) } - fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>; - fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>; - fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>; - fn build_prompt( + async fn opened_text_document( + &mut self, + params: DidOpenTextDocumentParams, + ) -> anyhow::Result<()>; + async fn changed_text_document( + &mut self, + params: DidChangeTextDocumentParams, + ) -> anyhow::Result<()>; + async fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>; + async fn build_prompt( &mut self, position: &TextDocumentPositionParams, prompt_for_type: PromptForType, ) -> anyhow::Result; - fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result; + async fn get_filter_text( + &self, + position: &TextDocumentPositionParams, + ) -> anyhow::Result; } impl TryFrom for Box { @@ -55,3 +65,15 @@ impl TryFrom for Box { } } } + +// This makes testing much easier. Every transformer backend takes in a prompt. When verifying they work, its +// easier to just pass in a default prompt. +#[cfg(test)] +impl Prompt { + pub fn default_with_cursor() -> Self { + Self { + context: r#"def test_context():\n pass"#.to_string(), + code: r#"def test_code():\n "#.to_string(), + } + } +} diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs index c7fce9e..da543e0 100644 --- a/src/memory_backends/postgresml/mod.rs +++ b/src/memory_backends/postgresml/mod.rs @@ -40,7 +40,7 @@ impl PostgresML { }; // TODO: Think on the naming of the collection // Maybe filter on metadata or I'm not sure - let collection = Collection::new("test-lsp-ai-2", Some(database_url))?; + let collection = Collection::new("test-lsp-ai-3", Some(database_url))?; // TODO: Review the pipeline let pipeline = Pipeline::new( "v1", @@ -50,7 +50,7 @@ impl PostgresML { "splitter": { "model": "recursive_character", "parameters": { - "chunk_size": 512, + "chunk_size": 1500, "chunk_overlap": 40 } }, @@ -90,7 +90,7 @@ impl PostgresML { .into_iter() .map(|path| { let text = std::fs::read_to_string(&path) - .expect(format!("Error reading path: {}", path).as_str()); + .unwrap_or_else(|_| panic!("Error reading path: {}", path)); json!({ "id": path, "text": text @@ -121,24 +121,28 @@ impl PostgresML { } } +#[async_trait::async_trait] impl MemoryBackend for PostgresML { #[instrument(skip(self))] - fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result { - self.file_store.get_filter_text(position) + async fn get_filter_text( + &self, + position: &TextDocumentPositionParams, + ) -> anyhow::Result { + self.file_store.get_filter_text(position).await } #[instrument(skip(self))] - fn build_prompt( + async fn build_prompt( &mut self, position: &TextDocumentPositionParams, prompt_for_type: PromptForType, ) -> anyhow::Result { - // This is blocking, but that is 
ok as we only query for it from the worker when we are actually doing a transform let query = self .file_store .get_characters_around_position(position, 512)?; - let res = self.runtime.block_on( - self.collection.vector_search( + let res = self + .collection + .vector_search( json!({ "query": { "fields": { @@ -151,8 +155,8 @@ impl MemoryBackend for PostgresML { }) .into(), &mut self.pipeline, - ), - )?; + ) + .await?; let context = res .into_iter() .map(|c| { @@ -176,7 +180,7 @@ impl MemoryBackend for PostgresML { } #[instrument(skip(self))] - fn opened_text_document( + async fn opened_text_document( &mut self, params: lsp_types::DidOpenTextDocumentParams, ) -> anyhow::Result<()> { @@ -185,68 +189,63 @@ impl MemoryBackend for PostgresML { let task_added_pipeline = self.added_pipeline; let mut task_collection = self.collection.clone(); let mut task_pipeline = self.pipeline.clone(); - self.runtime.spawn(async move { - if !task_added_pipeline { - task_collection - .add_pipeline(&mut task_pipeline) - .await - .expect("PGML - Error adding pipeline to collection"); - } + if !task_added_pipeline { + task_collection + .add_pipeline(&mut task_pipeline) + .await + .expect("PGML - Error adding pipeline to collection"); + } + task_collection + .upsert_documents( + vec![json!({ + "id": path, + "text": text + }) + .into()], + None, + ) + .await + .expect("PGML - Error upserting documents"); + self.file_store.opened_text_document(params).await + } + + #[instrument(skip(self))] + async fn changed_text_document( + &mut self, + params: lsp_types::DidChangeTextDocumentParams, + ) -> anyhow::Result<()> { + let path = params.text_document.uri.path().to_owned(); + self.debounce_tx.send(path)?; + self.file_store.changed_text_document(params).await + } + + #[instrument(skip(self))] + async fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { + let mut task_collection = self.collection.clone(); + let task_params = params.clone(); + for file in task_params.files { + task_collection + .delete_documents( + json!({ + "id": file.old_uri + }) + .into(), + ) + .await + .expect("PGML - Error deleting file"); + let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file"); task_collection .upsert_documents( vec![json!({ - "id": path, + "id": file.new_uri, "text": text }) .into()], None, ) .await - .expect("PGML - Error upserting documents"); - }); - self.file_store.opened_text_document(params) - } - - #[instrument(skip(self))] - fn changed_text_document( - &mut self, - params: lsp_types::DidChangeTextDocumentParams, - ) -> anyhow::Result<()> { - let path = params.text_document.uri.path().to_owned(); - self.debounce_tx.send(path)?; - self.file_store.changed_text_document(params) - } - - #[instrument(skip(self))] - fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { - let mut task_collection = self.collection.clone(); - let task_params = params.clone(); - self.runtime.spawn(async move { - for file in task_params.files { - task_collection - .delete_documents( - json!({ - "id": file.old_uri - }) - .into(), - ) - .await - .expect("PGML - Error deleting file"); - let text = - std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file"); - task_collection - .upsert_documents( - vec![json!({ - "id": file.new_uri, - "text": text - }) - .into()], - None, - ) - .await - .expect("PGML - Error adding pipeline to collection"); - } - }); - self.file_store.renamed_file(params) + .expect("PGML - Error adding pipeline to 
collection"); + } + self.file_store.renamed_file(params).await } } diff --git a/src/transformer_backends/llama_cpp/mod.rs b/src/transformer_backends/llama_cpp/mod.rs index ed012a6..8e865ce 100644 --- a/src/transformer_backends/llama_cpp/mod.rs +++ b/src/transformer_backends/llama_cpp/mod.rs @@ -2,7 +2,6 @@ use anyhow::Context; use hf_hub::api::sync::ApiBuilder; use tracing::{debug, instrument}; -use super::TransformerBackend; use crate::{ configuration::{self}, memory_backends::Prompt, @@ -16,6 +15,8 @@ use crate::{ mod model; use model::Model; +use super::TransformerBackend; + pub struct LlamaCPP { model: Model, configuration: configuration::ModelGGUF, @@ -62,9 +63,10 @@ impl LlamaCPP { } } +#[async_trait::async_trait] impl TransformerBackend for LlamaCPP { #[instrument(skip(self))] - fn do_completion(&self, prompt: &Prompt) -> anyhow::Result { + async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result { // let prompt = self.get_prompt_string(prompt)?; let prompt = &prompt.code; debug!("Prompt string for LLM: {}", prompt); @@ -75,7 +77,7 @@ impl TransformerBackend for LlamaCPP { } #[instrument(skip(self))] - fn do_generate(&self, prompt: &Prompt) -> anyhow::Result { + async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result { // let prompt = self.get_prompt_string(prompt)?; // debug!("Prompt string for LLM: {}", prompt); let prompt = &prompt.code; @@ -86,7 +88,7 @@ impl TransformerBackend for LlamaCPP { } #[instrument(skip(self))] - fn do_generate_stream( + async fn do_generate_stream( &self, _request: &GenerateStreamRequest, ) -> anyhow::Result { diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs index 38eea40..8b02357 100644 --- a/src/transformer_backends/mod.rs +++ b/src/transformer_backends/mod.rs @@ -9,11 +9,11 @@ use crate::{ mod llama_cpp; mod openai; +#[async_trait::async_trait] pub trait TransformerBackend { - // Should all take an enum of chat messages or just a string for completion - fn do_completion(&self, prompt: &Prompt) -> anyhow::Result; - fn do_generate(&self, prompt: &Prompt) -> anyhow::Result; - fn do_generate_stream( + async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result; + async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result; + async fn do_generate_stream( &self, request: &GenerateStreamRequest, ) -> anyhow::Result; diff --git a/src/transformer_backends/openai/mod.rs b/src/transformer_backends/openai/mod.rs index 2ff53c0..dfbf5bf 100644 --- a/src/transformer_backends/openai/mod.rs +++ b/src/transformer_backends/openai/mod.rs @@ -1,16 +1,22 @@ +// Something more about what this file is +// NOTE: When decoding responses from OpenAI compatbile services, we don't care about every field + +use anyhow::Context; use serde::Deserialize; -use serde_json::json; +use serde_json::{json, Value}; use tracing::instrument; -use super::TransformerBackend; use crate::{ - configuration, + configuration::{self, ChatMessage}, memory_backends::Prompt, + utils::{format_chat_messages, format_context_code}, worker::{ DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest, }, }; +use super::TransformerBackend; + pub struct OpenAI { configuration: configuration::OpenAI, } @@ -22,7 +28,19 @@ struct OpenAICompletionsChoice { #[derive(Deserialize)] struct OpenAICompletionsResponse { - choices: Vec, + choices: Option>, + error: Option, +} + +#[derive(Deserialize)] +struct OpenAIChatChoices { + message: ChatMessage, +} + +#[derive(Deserialize)] +struct OpenAIChatResponse { + choices: Option>, + error: 
Option, } impl OpenAI { @@ -42,7 +60,12 @@ impl OpenAI { anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API"); }; let res: OpenAICompletionsResponse = client - .post(&self.configuration.completions_endpoint) + .post( + self.configuration + .completions_endpoint + .as_ref() + .context("must specify `completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `chat_endpoint` and `chat` messages.")?, + ) .bearer_auth(token) .header("Content-Type", "application/json") .header("Accept", "application/json") @@ -51,7 +74,6 @@ impl OpenAI { "max_tokens": max_tokens, "n": 1, "top_p": self.configuration.top_p, - "top_k": self.configuration.top_k, "presence_penalty": self.configuration.presence_penalty, "frequency_penalty": self.configuration.frequency_penalty, "temperature": self.configuration.temperature, @@ -60,34 +82,219 @@ impl OpenAI { })) .send()? .json()?; - eprintln!("**********RECEIVED REQUEST********"); - Ok(res.choices[0].text.clone()) + if let Some(error) = res.error { + anyhow::bail!("{:?}", error.to_string()) + } else if let Some(choices) = res.choices { + Ok(choices[0].text.clone()) + } else { + anyhow::bail!("Unknown error while making request to OpenAI") + } + } + + fn get_chat(&self, messages: Vec, max_tokens: usize) -> anyhow::Result { + eprintln!( + "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******", + messages + ); + let client = reqwest::blocking::Client::new(); + let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name { + std::env::var(env_var_name)? + } else if let Some(token) = &self.configuration.auth_token { + token.to_string() + } else { + anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API"); + }; + let res: OpenAIChatResponse = client + .post( + self.configuration + .chat_endpoint + .as_ref() + .context("must specify `chat_endpoint` to use chat")?, + ) + .bearer_auth(token) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .json(&json!({ + "model": self.configuration.model, + "max_tokens": max_tokens, + "n": 1, + "top_p": self.configuration.top_p, + "presence_penalty": self.configuration.presence_penalty, + "frequency_penalty": self.configuration.frequency_penalty, + "temperature": self.configuration.temperature, + "messages": messages + })) + .send()? + .json()?; + if let Some(error) = res.error { + anyhow::bail!("{:?}", error.to_string()) + } else if let Some(choices) = res.choices { + Ok(choices[0].message.content.clone()) + } else { + anyhow::bail!("Unknown error while making request to OpenAI") + } + } } +#[async_trait::async_trait] impl TransformerBackend for OpenAI { #[instrument(skip(self))] - fn do_completion(&self, prompt: &Prompt) -> anyhow::Result { + async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result { eprintln!("--------------{:?}---------------", prompt); - let prompt = format!("{} \n\n {}", prompt.context, prompt.code); - let insert_text = self.get_completion(&prompt, self.configuration.max_tokens.completion)?; + let max_tokens = self.configuration.max_tokens.completion; + let insert_text = match &self.configuration.chat { + Some(c) => match &c.completion { + Some(completion_messages) => { + let messages = format_chat_messages(completion_messages, prompt); + self.get_chat(messages, max_tokens)?
+ } + None => self.get_completion( + &format_context_code(&prompt.context, &prompt.code), + max_tokens, + )?, + }, + None => self.get_completion( + &format_context_code(&prompt.context, &prompt.code), + max_tokens, + )?, + }; Ok(DoCompletionResponse { insert_text }) } #[instrument(skip(self))] - fn do_generate(&self, prompt: &Prompt) -> anyhow::Result { + async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result { eprintln!("--------------{:?}---------------", prompt); - let prompt = format!("{} \n\n {}", prompt.context, prompt.code); - let generated_text = - self.get_completion(&prompt, self.configuration.max_tokens.completion)?; + let max_tokens = self.configuration.max_tokens.generation; + let generated_text = match &self.configuration.chat { + Some(c) => match &c.generation { + Some(completion_messages) => { + let messages = format_chat_messages(completion_messages, prompt); + self.get_chat(messages, max_tokens)? + } + None => self.get_completion( + &format_context_code(&prompt.context, &prompt.code), + max_tokens, + )?, + }, + None => self.get_completion( + &format_context_code(&prompt.context, &prompt.code), + max_tokens, + )?, + }; Ok(DoGenerateResponse { generated_text }) } #[instrument(skip(self))] - fn do_generate_stream( + async fn do_generate_stream( &self, request: &GenerateStreamRequest, ) -> anyhow::Result { unimplemented!() } } + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn openai_completion_do_completion() -> anyhow::Result<()> { + let configuration: configuration::OpenAI = serde_json::from_value(json!({ + "completions_endpoint": "https://api.openai.com/v1/completions", + "model": "gpt-3.5-turbo-instruct", + "auth_token_env_var_name": "OPENAI_API_KEY", + "max_tokens": { + "completion": 16, + "generation": 64 + }, + "max_context": 4096 + }))?; + let openai = OpenAI::new(configuration); + let prompt = Prompt::default_with_cursor(); + let response = openai.do_completion(&prompt).await?; + assert!(!response.insert_text.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn openai_chat_do_completion() -> anyhow::Result<()> { + let configuration: configuration::OpenAI = serde_json::from_value(json!({ + "chat_endpoint": "https://api.openai.com/v1/chat/completions", + "model": "gpt-3.5-turbo", + "auth_token_env_var_name": "OPENAI_API_KEY", + "chat": { + "completion": [ + { + "role": "system", + "content": "You are a coding assistant. You job is to generate a code snippet to replace .\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the location.\n- Only respond with code. Do not respond with anything that is not valid code." 
+ }, + { + "role": "user", + "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}" + } + ], + }, + "max_tokens": { + "completion": 16, + "generation": 64 + }, + "max_context": 4096 + }))?; + let openai = OpenAI::new(configuration); + let prompt = Prompt::default_with_cursor(); + let response = openai.do_completion(&prompt).await?; + assert!(!response.insert_text.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn openai_completion_do_generate() -> anyhow::Result<()> { + let configuration: configuration::OpenAI = serde_json::from_value(json!({ + "completions_endpoint": "https://api.openai.com/v1/completions", + "model": "gpt-3.5-turbo-instruct", + "auth_token_env_var_name": "OPENAI_API_KEY", + "max_tokens": { + "completion": 16, + "generation": 64 + }, + "max_context": 4096 + }))?; + let openai = OpenAI::new(configuration); + let prompt = Prompt::default_with_cursor(); + let response = openai.do_generate(&prompt).await?; + assert!(!response.generated_text.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn openai_chat_do_generate() -> anyhow::Result<()> { + let configuration: configuration::OpenAI = serde_json::from_value(json!({ + "config": { + "chat_endpoint": "https://api.openai.com/v1/chat/completions", + "model": "gpt-3.5-turbo", + "auth_token_env_var_name": "OPENAI_API_KEY", + "chat": { + "generation": [ + { + "role": "system", + "content": "You are a coding assistant. You job is to generate a code snippet to replace .\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the location.\n- Only respond with code. Do not respond with anything that is not valid code." 
+ }, + { + "role": "user", + "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}" + } + ] + }, + "max_tokens": { + "completion": 16, + "generation": 64 + }, + "max_context": 4096 + }}))?; + let openai = OpenAI::new(configuration); + let prompt = Prompt::default_with_cursor(); + let response = openai.do_generate(&prompt).await?; + assert!(!response.generated_text.is_empty()); + Ok(()) + } +} diff --git a/src/utils.rs b/src/utils.rs index ce33964..ff24a26 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -20,15 +20,19 @@ pub fn tokens_to_estimated_characters(tokens: usize) -> usize { tokens * 4 } -pub fn format_chat_messages(messages: &Vec, prompt: &Prompt) -> Vec { +pub fn format_chat_messages(messages: &[ChatMessage], prompt: &Prompt) -> Vec { messages .iter() .map(|m| ChatMessage { role: m.role.to_owned(), content: m .content - .replace("{context}", &prompt.context) - .replace("{code}", &prompt.code), + .replace("{CONTEXT}", &prompt.context) + .replace("{CODE}", &prompt.code), }) .collect() } + +pub fn format_context_code(context: &str, code: &str) -> String { + format!("{context}\n\n{code}") +} diff --git a/src/worker.rs b/src/worker.rs index 87f8fd8..020a596 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -71,7 +71,7 @@ pub struct DoGenerateStreamResponse { } pub struct Worker { - transformer_backend: Box, + transformer_backend: Box, memory_backend: Arc>>, last_worker_request: Arc>>, connection: Arc, @@ -79,7 +79,7 @@ pub struct Worker { impl Worker { pub fn new( - transformer_backend: Box, + transformer_backend: Box, memory_backend: Arc>>, last_worker_request: Arc>>, connection: Arc, @@ -92,8 +92,7 @@ impl Worker { } } - #[instrument(skip(self))] - fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result { + async fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result { let prompt = self.memory_backend.lock().build_prompt( &request.params.text_document_position, PromptForType::Completion, @@ -102,7 +101,7 @@ impl Worker { .memory_backend .lock() .get_filter_text(&request.params.text_document_position)?; - let response = self.transformer_backend.do_completion(&prompt)?; + let response = self.transformer_backend.do_completion(&prompt).await?; let completion_text_edit = TextEdit::new( Range::new( Position::new( @@ -128,7 +127,7 @@ impl Worker { items: vec![item], }; let result = Some(CompletionResponse::List(completion_list)); - let result = serde_json::to_value(&result).unwrap(); + let result = serde_json::to_value(result).unwrap(); Ok(Response { id: request.id.clone(), result: Some(result), @@ -137,16 +136,16 @@ impl Worker { } #[instrument(skip(self))] - fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result { + async fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result { let prompt = self.memory_backend.lock().build_prompt( &request.params.text_document_position, PromptForType::Generate, )?; - let response = self.transformer_backend.do_generate(&prompt)?; + let response = self.transformer_backend.do_generate(&prompt).await?; let result = GenerateResult { generated_text: response.generated_text, }; - let result = serde_json::to_value(&result).unwrap(); + let result = serde_json::to_value(result).unwrap(); Ok(Response { id: request.id.clone(), result: Some(result), @@ -154,36 +153,48 @@ impl Worker { }) } - pub fn run(self) { + pub fn run(self) -> anyhow::Result<()> { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build()?; loop { let option_worker_request: Option = { 
let mut completion_request = self.last_worker_request.lock(); std::mem::take(&mut *completion_request) }; if let Some(request) = option_worker_request { - let response = match request { - WorkerRequest::Completion(request) => match self.do_completion(&request) { - Ok(r) => r, - Err(e) => Response { - id: request.id, - result: None, - error: Some(e.to_response_error(-32603)), - }, - }, - WorkerRequest::Generate(request) => match self.do_generate(&request) { - Ok(r) => r, - Err(e) => Response { - id: request.id, - result: None, - error: Some(e.to_response_error(-32603)), - }, - }, - WorkerRequest::GenerateStream(_) => panic!("Streaming is not supported yet"), - }; - self.connection - .sender - .send(Message::Response(response)) - .expect("Error sending message"); + runtime.spawn(async move { + let response = match request { + WorkerRequest::Completion(request) => { + match self.do_completion(&request).await { + Ok(r) => r, + Err(e) => Response { + id: request.id, + result: None, + error: Some(e.to_response_error(-32603)), + }, + } + } + WorkerRequest::Generate(request) => { + match self.do_generate(&request).await { + Ok(r) => r, + Err(e) => Response { + id: request.id, + result: None, + error: Some(e.to_response_error(-32603)), + }, + } + } + WorkerRequest::GenerateStream(_) => { + panic!("Streaming is not supported yet") + } + }; + self.connection + .sender + .send(Message::Response(response)) + .expect("Error sending message"); + }); } thread::sleep(std::time::Duration::from_millis(5)); } From 2f71a4de3ea2eca746127464675e5ffe9a74701e Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Sat, 23 Mar 2024 15:44:40 -0700 Subject: [PATCH 2/3] Working overhaul --- Cargo.lock | 1 + src/main.rs | 58 +++--- src/memory_backends/file_store.rs | 45 +++-- src/memory_backends/mod.rs | 13 +- src/memory_backends/postgresml/mod.rs | 12 +- src/memory_worker.rs | 100 ++++++++++ src/transformer_backends/llama_cpp/mod.rs | 4 +- src/transformer_backends/mod.rs | 4 +- src/transformer_backends/openai/mod.rs | 4 +- src/transformer_worker.rs | 214 ++++++++++++++++++++++ src/worker.rs | 202 -------------------- submodules/postgresml | 2 +- 12 files changed, 394 insertions(+), 265 deletions(-) create mode 100644 src/memory_worker.rs create mode 100644 src/transformer_worker.rs delete mode 100644 src/worker.rs diff --git a/Cargo.lock b/Cargo.lock index 0cfba1e..d442fc5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1967,6 +1967,7 @@ dependencies = [ "itertools 0.10.5", "lopdf", "md5", + "once_cell", "parking_lot", "regex", "reqwest", diff --git a/src/main.rs b/src/main.rs index 2d3b5af..2569aea 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,25 +6,31 @@ use lsp_types::{ RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, }; use parking_lot::Mutex; -use std::{sync::Arc, thread}; +use std::{ + sync::{mpsc, Arc}, + thread, +}; use tracing::error; use tracing_subscriber::{EnvFilter, FmtSubscriber}; mod configuration; mod custom_requests; mod memory_backends; +mod memory_worker; mod template; mod transformer_backends; +mod transformer_worker; mod utils; -mod worker; use configuration::Configuration; use custom_requests::generate::Generate; use memory_backends::MemoryBackend; use transformer_backends::TransformerBackend; -use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest}; +use transformer_worker::{CompletionRequest, GenerateRequest, WorkerRequest}; -use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest}; +use 
crate::{ + custom_requests::generate_stream::GenerateStream, transformer_worker::GenerateStreamRequest, +}; fn notification_is(notification: &Notification) -> bool { notification.method == N::METHOD @@ -71,35 +77,39 @@ fn main() -> Result<()> { // Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request // Note that we also want to have the memory backend in the worker thread as that may also involve heavy computations fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { - let args = Configuration::new(args)?; - - // Set the transformer_backend - let transformer_backend: Box = args.clone().try_into()?; - - // Set the memory_backend - let memory_backend: Box = - Arc::new(Mutex::new(args.clone().try_into()?)); + // Build our configuration + let configuration = Configuration::new(args)?; // Wrap the connection for sharing between threads let connection = Arc::new(connection); - // How we communicate between the worker and receiver threads + // Our channel we use to communicate with our transformer_worker let last_worker_request = Arc::new(Mutex::new(None)); + // Setup our memory_worker + // TODO: Setup some kind of error handler + // Set the memory_backend + // The channel we use to communicate with our memory_worker + let (memory_tx, memory_rx) = mpsc::channel(); + let memory_backend: Box = configuration.clone().try_into()?; + thread::spawn(move || memory_worker::run(memory_backend, memory_rx)); + + // Setup our transformer_worker // Thread local variables // TODO: Setup some kind of handler for errors here - let thread_memory_backend = memory_backend.clone(); + // Set the transformer_backend + let transformer_backend: Box = + configuration.clone().try_into()?; let thread_last_worker_request = last_worker_request.clone(); let thread_connection = connection.clone(); + let thread_memory_tx = memory_tx.clone(); thread::spawn(move || { - Worker::new( + transformer_worker::run( transformer_backend, - thread_memory_backend, + thread_memory_tx, thread_last_worker_request, thread_connection, ) - .run() - .unwrap(); }); for msg in &connection.receiver { @@ -145,13 +155,13 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { Message::Notification(not) => { if notification_is::(¬) { let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?; - // memory_backend.lock().opened_text_document(params)?; + memory_tx.send(memory_worker::WorkerRequest::DidOpenTextDocument(params))?; } else if notification_is::(¬) { let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?; - // memory_backend.lock().changed_text_document(params)?; + memory_tx.send(memory_worker::WorkerRequest::DidChangeTextDocument(params))?; } else if notification_is::(¬) { let params: RenameFilesParams = serde_json::from_value(not.params)?; - // memory_backend.lock().renamed_file(params)?; + memory_tx.send(memory_worker::WorkerRequest::DidRenameFiles(params))?; } } _ => (), @@ -176,7 +186,8 @@ mod tests { async fn completion_with_default_arguments() { let args = json!({}); let configuration = Configuration::new(args).unwrap(); - let backend: Box = configuration.clone().try_into().unwrap(); + let backend: Box = + configuration.clone().try_into().unwrap(); let prompt = Prompt::new("".to_string(), "def fibn".to_string()); let response = backend.do_completion(&prompt).await.unwrap(); assert!(!response.insert_text.is_empty()) @@ -232,7 +243,8 @@ mod tests { } }); 
let configuration = Configuration::new(args).unwrap(); - let backend: Box = configuration.clone().try_into().unwrap(); + let backend: Box = + configuration.clone().try_into().unwrap(); let prompt = Prompt::new("".to_string(), "def fibn".to_string()); let response = backend.do_completion(&prompt).await.unwrap(); assert!(!response.insert_text.is_empty()); diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs index e8d832e..6f56339 100644 --- a/src/memory_backends/file_store.rs +++ b/src/memory_backends/file_store.rs @@ -1,8 +1,9 @@ use anyhow::Context; use indexmap::IndexSet; use lsp_types::TextDocumentPositionParams; +use parking_lot::Mutex; use ropey::Rope; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use tracing::instrument; use crate::{ @@ -15,8 +16,8 @@ use super::{MemoryBackend, Prompt, PromptForType}; pub struct FileStore { crawl: bool, configuration: Configuration, - file_map: HashMap, - accessed_files: IndexSet, + file_map: Mutex>, + accessed_files: Mutex>, } // TODO: Put some thought into the crawling here. Do we want to have a crawl option where it tries to crawl through all relevant @@ -37,8 +38,8 @@ impl FileStore { Self { crawl: file_store_config.crawl, configuration, - file_map: HashMap::new(), - accessed_files: IndexSet::new(), + file_map: Mutex::new(HashMap::new()), + accessed_files: Mutex::new(IndexSet::new()), } } @@ -46,8 +47,8 @@ impl FileStore { Self { crawl: false, configuration, - file_map: HashMap::new(), - accessed_files: IndexSet::new(), + file_map: Mutex::new(HashMap::new()), + accessed_files: Mutex::new(IndexSet::new()), } } @@ -60,6 +61,7 @@ impl FileStore { let current_document_uri = position.text_document.uri.to_string(); let mut rope = self .file_map + .lock() .get(¤t_document_uri) .context("Error file not found")? .clone(); @@ -68,6 +70,7 @@ impl FileStore { // Add to our rope if we need to for file in self .accessed_files + .lock() .iter() .filter(|f| **f != current_document_uri) { @@ -75,7 +78,8 @@ impl FileStore { if needed == 0 { break; } - let r = self.file_map.get(file).context("Error file not found")?; + let file_map = self.file_map.lock(); + let r = file_map.get(file).context("Error file not found")?; let slice_max = needed.min(r.len_chars()); let rope_str_slice = r .get_slice(0..slice_max) @@ -94,6 +98,7 @@ impl FileStore { ) -> anyhow::Result { let rope = self .file_map + .lock() .get(position.text_document.uri.as_str()) .context("Error file not found")? .clone(); @@ -186,6 +191,7 @@ impl MemoryBackend for FileStore { ) -> anyhow::Result { let rope = self .file_map + .lock() .get(position.text_document.uri.as_str()) .context("Error file not found")? 
.clone(); @@ -197,7 +203,7 @@ impl MemoryBackend for FileStore { #[instrument(skip(self))] async fn build_prompt( - &mut self, + &self, position: &TextDocumentPositionParams, prompt_for_type: PromptForType, ) -> anyhow::Result { @@ -211,24 +217,24 @@ impl MemoryBackend for FileStore { #[instrument(skip(self))] async fn opened_text_document( - &mut self, + &self, params: lsp_types::DidOpenTextDocumentParams, ) -> anyhow::Result<()> { let rope = Rope::from_str(¶ms.text_document.text); let uri = params.text_document.uri.to_string(); - self.file_map.insert(uri.clone(), rope); - self.accessed_files.shift_insert(0, uri); + self.file_map.lock().insert(uri.clone(), rope); + self.accessed_files.lock().shift_insert(0, uri); Ok(()) } #[instrument(skip(self))] async fn changed_text_document( - &mut self, + &self, params: lsp_types::DidChangeTextDocumentParams, ) -> anyhow::Result<()> { let uri = params.text_document.uri.to_string(); - let rope = self - .file_map + let mut file_map = self.file_map.lock(); + let rope = file_map .get_mut(&uri) .context("Error trying to get file that does not exist")?; for change in params.content_changes { @@ -244,15 +250,16 @@ impl MemoryBackend for FileStore { *rope = Rope::from_str(&change.text); } } - self.accessed_files.shift_insert(0, uri); + self.accessed_files.lock().shift_insert(0, uri); Ok(()) } #[instrument(skip(self))] - async fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { + async fn renamed_file(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { for file_rename in params.files { - if let Some(rope) = self.file_map.remove(&file_rename.old_uri) { - self.file_map.insert(file_rename.new_uri, rope); + let mut file_map = self.file_map.lock(); + if let Some(rope) = file_map.remove(&file_rename.old_uri) { + file_map.insert(file_rename.new_uri, rope); } } Ok(()) diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs index cb8c6b6..58e39f1 100644 --- a/src/memory_backends/mod.rs +++ b/src/memory_backends/mod.rs @@ -31,17 +31,14 @@ pub trait MemoryBackend { async fn init(&self) -> anyhow::Result<()> { Ok(()) } - async fn opened_text_document( - &mut self, - params: DidOpenTextDocumentParams, - ) -> anyhow::Result<()>; + async fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>; async fn changed_text_document( - &mut self, + &self, params: DidChangeTextDocumentParams, ) -> anyhow::Result<()>; - async fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>; + async fn renamed_file(&self, params: RenameFilesParams) -> anyhow::Result<()>; async fn build_prompt( - &mut self, + &self, position: &TextDocumentPositionParams, prompt_for_type: PromptForType, ) -> anyhow::Result; @@ -51,7 +48,7 @@ pub trait MemoryBackend { ) -> anyhow::Result; } -impl TryFrom for Box { +impl TryFrom for Box { type Error = anyhow::Error; fn try_from(configuration: Configuration) -> Result { diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs index da543e0..f4d4791 100644 --- a/src/memory_backends/postgresml/mod.rs +++ b/src/memory_backends/postgresml/mod.rs @@ -133,7 +133,7 @@ impl MemoryBackend for PostgresML { #[instrument(skip(self))] async fn build_prompt( - &mut self, + &self, position: &TextDocumentPositionParams, prompt_for_type: PromptForType, ) -> anyhow::Result { @@ -142,7 +142,7 @@ impl MemoryBackend for PostgresML { .get_characters_around_position(position, 512)?; let res = self .collection - .vector_search( + 
.vector_search_local( json!({ "query": { "fields": { @@ -154,7 +154,7 @@ impl MemoryBackend for PostgresML { "limit": 5 }) .into(), - &mut self.pipeline, + &self.pipeline, ) .await?; let context = res @@ -181,7 +181,7 @@ impl MemoryBackend for PostgresML { #[instrument(skip(self))] async fn opened_text_document( - &mut self, + &self, params: lsp_types::DidOpenTextDocumentParams, ) -> anyhow::Result<()> { let text = params.text_document.text.clone(); @@ -211,7 +211,7 @@ impl MemoryBackend for PostgresML { #[instrument(skip(self))] async fn changed_text_document( - &mut self, + &self, params: lsp_types::DidChangeTextDocumentParams, ) -> anyhow::Result<()> { let path = params.text_document.uri.path().to_owned(); @@ -220,7 +220,7 @@ impl MemoryBackend for PostgresML { } #[instrument(skip(self))] - async fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { + async fn renamed_file(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { let mut task_collection = self.collection.clone(); let task_params = params.clone(); for file in task_params.files { diff --git a/src/memory_worker.rs b/src/memory_worker.rs new file mode 100644 index 0000000..be2e13d --- /dev/null +++ b/src/memory_worker.rs @@ -0,0 +1,100 @@ +use std::sync::Arc; + +use lsp_types::{ + DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams, + TextDocumentPositionParams, +}; + +use crate::memory_backends::{MemoryBackend, Prompt, PromptForType}; + +#[derive(Debug)] +pub struct PromptRequest { + position: TextDocumentPositionParams, + prompt_for_type: PromptForType, + tx: tokio::sync::oneshot::Sender, +} + +impl PromptRequest { + pub fn new( + position: TextDocumentPositionParams, + prompt_for_type: PromptForType, + tx: tokio::sync::oneshot::Sender, + ) -> Self { + Self { + position, + prompt_for_type, + tx, + } + } +} + +#[derive(Debug)] +pub struct FilterRequest { + position: TextDocumentPositionParams, + tx: tokio::sync::oneshot::Sender, +} + +impl FilterRequest { + pub fn new( + position: TextDocumentPositionParams, + tx: tokio::sync::oneshot::Sender, + ) -> Self { + Self { position, tx } + } +} + +pub enum WorkerRequest { + FilterText(FilterRequest), + Prompt(PromptRequest), + DidOpenTextDocument(DidOpenTextDocumentParams), + DidChangeTextDocument(DidChangeTextDocumentParams), + DidRenameFiles(RenameFilesParams), +} + +pub fn run( + memory_backend: Box, + rx: std::sync::mpsc::Receiver, +) -> anyhow::Result<()> { + let memory_backend = Arc::new(memory_backend); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build()?; + loop { + let request = rx.recv()?; + let thread_memory_backend = memory_backend.clone(); + runtime.spawn(async move { + match request { + WorkerRequest::FilterText(params) => { + let filter_text = thread_memory_backend + .get_filter_text(¶ms.position) + .await + .unwrap(); + params.tx.send(filter_text).unwrap(); + } + WorkerRequest::Prompt(params) => { + let prompt = thread_memory_backend + .build_prompt(¶ms.position, params.prompt_for_type) + .await + .unwrap(); + params.tx.send(prompt).unwrap(); + } + WorkerRequest::DidOpenTextDocument(params) => { + thread_memory_backend + .opened_text_document(params) + .await + .unwrap(); + } + WorkerRequest::DidChangeTextDocument(params) => { + thread_memory_backend + .changed_text_document(params) + .await + .unwrap(); + } + WorkerRequest::DidRenameFiles(params) => { + thread_memory_backend.renamed_file(params).await.unwrap() + } + } + }); + } +} diff --git 
diff --git a/src/transformer_backends/llama_cpp/mod.rs b/src/transformer_backends/llama_cpp/mod.rs
index 8e865ce..b410e8f 100644
--- a/src/transformer_backends/llama_cpp/mod.rs
+++ b/src/transformer_backends/llama_cpp/mod.rs
@@ -6,10 +6,10 @@ use crate::{
     configuration::{self},
     memory_backends::Prompt,
     template::apply_chat_template,
-    utils::format_chat_messages,
-    worker::{
+    transformer_worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
+    utils::format_chat_messages,
 };
 
 mod model;
diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs
index 8b02357..4883cfb 100644
--- a/src/transformer_backends/mod.rs
+++ b/src/transformer_backends/mod.rs
@@ -1,7 +1,7 @@
 use crate::{
     configuration::{Configuration, ValidTransformerBackend},
     memory_backends::Prompt,
-    worker::{
+    transformer_worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
 };
@@ -19,7 +19,7 @@ pub trait TransformerBackend {
     ) -> anyhow::Result<DoGenerateStreamResponse>;
 }
 
-impl TryFrom<Configuration> for Box<dyn TransformerBackend> {
+impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send + Sync> {
     type Error = anyhow::Error;
 
     fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
diff --git a/src/transformer_backends/openai/mod.rs b/src/transformer_backends/openai/mod.rs
index dfbf5bf..5fcc2d6 100644
--- a/src/transformer_backends/openai/mod.rs
+++ b/src/transformer_backends/openai/mod.rs
@@ -9,10 +9,10 @@ use crate::{
     configuration::{self, ChatMessage},
     memory_backends::Prompt,
-    utils::{format_chat_messages, format_context_code},
-    worker::{
+    transformer_worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
+    utils::{format_chat_messages, format_context_code},
 };
 
 use super::TransformerBackend;
diff --git a/src/transformer_worker.rs b/src/transformer_worker.rs
new file mode 100644
index 0000000..6357639
--- /dev/null
+++ b/src/transformer_worker.rs
@@ -0,0 +1,214 @@
+use lsp_server::{Connection, Message, RequestId, Response};
+use lsp_types::{
+    CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
+    Position, Range, TextEdit,
+};
+use parking_lot::Mutex;
+use std::{sync::Arc, thread};
+use tokio::sync::oneshot;
+
+use crate::custom_requests::generate::{GenerateParams, GenerateResult};
+use crate::custom_requests::generate_stream::GenerateStreamParams;
+use crate::memory_backends::PromptForType;
+use crate::memory_worker::{self, FilterRequest, PromptRequest};
+use crate::transformer_backends::TransformerBackend;
+use crate::utils::ToResponseError;
+
+#[derive(Clone, Debug)]
+pub struct CompletionRequest {
+    id: RequestId,
+    params: CompletionParams,
+}
+
+impl CompletionRequest {
+    pub fn new(id: RequestId, params: CompletionParams) -> Self {
+        Self { id, params }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct GenerateRequest {
+    id: RequestId,
+    params: GenerateParams,
+}
+
+impl GenerateRequest {
+    pub fn new(id: RequestId, params: GenerateParams) -> Self {
+        Self { id, params }
+    }
+}
+
+// The generate stream is not yet ready but we don't want to remove it
+#[allow(dead_code)]
+#[derive(Clone, Debug)]
+pub struct GenerateStreamRequest {
+    id: RequestId,
+    params: GenerateStreamParams,
+}
+
+impl GenerateStreamRequest {
+    pub fn new(id: RequestId, params: GenerateStreamParams) -> Self {
+        Self { id, params }
+    }
+}
+
+#[derive(Clone)]
+pub enum WorkerRequest {
+    Completion(CompletionRequest),
+    Generate(GenerateRequest),
+    GenerateStream(GenerateStreamRequest),
+}
+
+pub struct DoCompletionResponse {
+    pub insert_text: String,
+}
+
+pub struct DoGenerateResponse {
+    pub generated_text: String,
+}
+
+pub struct DoGenerateStreamResponse {
+    pub generated_text: String,
+}
+
+pub fn run(
+    transformer_backend: Box<dyn TransformerBackend + Send + Sync>,
+    memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
+    last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
+    connection: Arc<Connection>,
+) -> anyhow::Result<()> {
+    let transformer_backend = Arc::new(transformer_backend);
+    let runtime = tokio::runtime::Builder::new_multi_thread()
+        .worker_threads(4)
+        .enable_all()
+        .build()?;
+    loop {
+        let option_worker_request: Option<WorkerRequest> = {
+            let mut completion_request = last_worker_request.lock();
+            std::mem::take(&mut *completion_request)
+        };
+        if let Some(request) = option_worker_request {
+            let thread_connection = connection.clone();
+            let thread_transformer_backend = transformer_backend.clone();
+            let thread_memory_backend_tx = memory_backend_tx.clone();
+            runtime.spawn(async move {
+                let response = match request {
+                    WorkerRequest::Completion(request) => match do_completion(
+                        thread_transformer_backend,
+                        thread_memory_backend_tx,
+                        &request,
+                    )
+                    .await
+                    {
+                        Ok(r) => r,
+                        Err(e) => Response {
+                            id: request.id,
+                            result: None,
+                            error: Some(e.to_response_error(-32603)),
+                        },
+                    },
+                    WorkerRequest::Generate(request) => match do_generate(
+                        thread_transformer_backend,
+                        thread_memory_backend_tx,
+                        &request,
+                    )
+                    .await
+                    {
+                        Ok(r) => r,
+                        Err(e) => Response {
+                            id: request.id,
+                            result: None,
+                            error: Some(e.to_response_error(-32603)),
+                        },
+                    },
+                    WorkerRequest::GenerateStream(_) => {
+                        panic!("Streaming is not supported yet")
+                    }
+                };
+                thread_connection
+                    .sender
+                    .send(Message::Response(response))
+                    .expect("Error sending message");
+            });
+        }
+        thread::sleep(std::time::Duration::from_millis(5));
+    }
+}
+
+async fn do_completion(
+    transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
+    memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
+    request: &CompletionRequest,
+) -> anyhow::Result<Response> {
+    let (tx, rx) = oneshot::channel();
+    memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
+        request.params.text_document_position.clone(),
+        PromptForType::Completion,
+        tx,
+    )))?;
+    let prompt = rx.await?;
+
+    let (tx, rx) = oneshot::channel();
+    memory_backend_tx.send(memory_worker::WorkerRequest::FilterText(
+        FilterRequest::new(request.params.text_document_position.clone(), tx),
+    ))?;
+    let filter_text = rx.await?;
+
+    let response = transformer_backend.do_completion(&prompt).await?;
+    let completion_text_edit = TextEdit::new(
+        Range::new(
+            Position::new(
+                request.params.text_document_position.position.line,
+                request.params.text_document_position.position.character,
+            ),
+            Position::new(
+                request.params.text_document_position.position.line,
+                request.params.text_document_position.position.character,
+            ),
+        ),
+        response.insert_text.clone(),
+    );
+    let item = CompletionItem {
+        label: format!("ai - {}", response.insert_text),
+        filter_text: Some(filter_text),
+        text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
+        kind: Some(CompletionItemKind::TEXT),
+        ..Default::default()
+    };
+    let completion_list = CompletionList {
+        is_incomplete: false,
+        items: vec![item],
+    };
+    let result = Some(CompletionResponse::List(completion_list));
+    let result = serde_json::to_value(result).unwrap();
+    Ok(Response {
+        id: request.id.clone(),
+        result: Some(result),
+        error: None,
+    })
+}
+
+async fn do_generate(
+    transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
+    memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
+    request: &GenerateRequest,
+) -> anyhow::Result<Response> {
+    let (tx, rx) = oneshot::channel();
+    memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
+        request.params.text_document_position.clone(),
+        PromptForType::Completion,
+        tx,
+    )))?;
+    let prompt = rx.await?;
+
+    let response = transformer_backend.do_generate(&prompt).await?;
+    let result = GenerateResult {
+        generated_text: response.generated_text,
+    };
+    let result = serde_json::to_value(result).unwrap();
+    Ok(Response {
+        id: request.id.clone(),
+        result: Some(result),
+        error: None,
+    })
+}
diff --git a/src/worker.rs b/src/worker.rs
deleted file mode 100644
index 020a596..0000000
--- a/src/worker.rs
+++ /dev/null
@@ -1,202 +0,0 @@
-use lsp_server::{Connection, Message, RequestId, Response};
-use lsp_types::{
-    CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
-    Position, Range, TextEdit,
-};
-use parking_lot::Mutex;
-use std::{sync::Arc, thread};
-use tracing::instrument;
-
-use crate::custom_requests::generate::{GenerateParams, GenerateResult};
-use crate::custom_requests::generate_stream::GenerateStreamParams;
-use crate::memory_backends::{MemoryBackend, PromptForType};
-use crate::transformer_backends::TransformerBackend;
-use crate::utils::ToResponseError;
-
-#[derive(Clone, Debug)]
-pub struct CompletionRequest {
-    id: RequestId,
-    params: CompletionParams,
-}
-
-impl CompletionRequest {
-    pub fn new(id: RequestId, params: CompletionParams) -> Self {
-        Self { id, params }
-    }
-}
-
-#[derive(Clone, Debug)]
-pub struct GenerateRequest {
-    id: RequestId,
-    params: GenerateParams,
-}
-
-impl GenerateRequest {
-    pub fn new(id: RequestId, params: GenerateParams) -> Self {
-        Self { id, params }
-    }
-}
-
-// The generate stream is not yet ready but we don't want to remove it
-#[allow(dead_code)]
-#[derive(Clone, Debug)]
-pub struct GenerateStreamRequest {
-    id: RequestId,
-    params: GenerateStreamParams,
-}
-
-impl GenerateStreamRequest {
-    pub fn new(id: RequestId, params: GenerateStreamParams) -> Self {
-        Self { id, params }
-    }
-}
-
-#[derive(Clone)]
-pub enum WorkerRequest {
-    Completion(CompletionRequest),
-    Generate(GenerateRequest),
-    GenerateStream(GenerateStreamRequest),
-}
-
-pub struct DoCompletionResponse {
-    pub insert_text: String,
-}
-
-pub struct DoGenerateResponse {
-    pub generated_text: String,
-}
-
-pub struct DoGenerateStreamResponse {
-    pub generated_text: String,
-}
-
-pub struct Worker {
-    transformer_backend: Box<dyn TransformerBackend>,
-    memory_backend: Arc<Mutex<Box<dyn MemoryBackend>>>,
-    last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
-    connection: Arc<Connection>,
-}
-
-impl Worker {
-    pub fn new(
-        transformer_backend: Box<dyn TransformerBackend>,
-        memory_backend: Arc<Mutex<Box<dyn MemoryBackend>>>,
-        last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
-        connection: Arc<Connection>,
-    ) -> Self {
-        Self {
-            transformer_backend,
-            memory_backend,
-            last_worker_request,
-            connection,
-        }
-    }
-
-    async fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
-        let prompt = self.memory_backend.lock().build_prompt(
-            &request.params.text_document_position,
-            PromptForType::Completion,
-        )?;
-        let filter_text = self
-            .memory_backend
-            .lock()
-            .get_filter_text(&request.params.text_document_position)?;
-        let response = self.transformer_backend.do_completion(&prompt).await?;
-        let completion_text_edit = TextEdit::new(
-            Range::new(
-                Position::new(
-                    request.params.text_document_position.position.line,
-                    request.params.text_document_position.position.character,
-                ),
-                Position::new(
-                    request.params.text_document_position.position.line,
-                    request.params.text_document_position.position.character,
-                ),
-            ),
-            response.insert_text.clone(),
-        );
-        let item = CompletionItem {
-            label: format!("ai - {}", response.insert_text),
-            filter_text: Some(filter_text),
-            text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
-            kind: Some(CompletionItemKind::TEXT),
-            ..Default::default()
-        };
-        let completion_list = CompletionList {
-            is_incomplete: false,
-            items: vec![item],
-        };
-        let result = Some(CompletionResponse::List(completion_list));
-        let result = serde_json::to_value(result).unwrap();
-        Ok(Response {
-            id: request.id.clone(),
-            result: Some(result),
-            error: None,
-        })
-    }
-
-    #[instrument(skip(self))]
-    async fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
-        let prompt = self.memory_backend.lock().build_prompt(
-            &request.params.text_document_position,
-            PromptForType::Generate,
-        )?;
-        let response = self.transformer_backend.do_generate(&prompt).await?;
-        let result = GenerateResult {
-            generated_text: response.generated_text,
-        };
-        let result = serde_json::to_value(result).unwrap();
-        Ok(Response {
-            id: request.id.clone(),
-            result: Some(result),
-            error: None,
-        })
-    }
-
-    pub fn run(self) -> anyhow::Result<()> {
-        let runtime = tokio::runtime::Builder::new_multi_thread()
-            .worker_threads(4)
-            .enable_all()
-            .build()?;
-        loop {
-            let option_worker_request: Option<WorkerRequest> = {
-                let mut completion_request = self.last_worker_request.lock();
-                std::mem::take(&mut *completion_request)
-            };
-            if let Some(request) = option_worker_request {
-                runtime.spawn(async move {
-                    let response = match request {
-                        WorkerRequest::Completion(request) => {
-                            match self.do_completion(&request).await {
-                                Ok(r) => r,
-                                Err(e) => Response {
-                                    id: request.id,
-                                    result: None,
-                                    error: Some(e.to_response_error(-32603)),
-                                },
-                            }
-                        }
-                        WorkerRequest::Generate(request) => {
-                            match self.do_generate(&request).await {
-                                Ok(r) => r,
-                                Err(e) => Response {
-                                    id: request.id,
-                                    result: None,
-                                    error: Some(e.to_response_error(-32603)),
-                                },
-                            }
-                        }
-                        WorkerRequest::GenerateStream(_) => {
-                            panic!("Streaming is not supported yet")
-                        }
-                    };
-                    self.connection
-                        .sender
-                        .send(Message::Response(response))
-                        .expect("Error sending message");
-                });
-            }
-            thread::sleep(std::time::Duration::from_millis(5));
-        }
-    }
-}
diff --git a/submodules/postgresml b/submodules/postgresml
index 0842673..a16ff70 160000
--- a/submodules/postgresml
+++ b/submodules/postgresml
@@ -1 +1 @@
-Subproject commit 0842673804b0af18afbb24cd0e4fc00b01fff1fa
+Subproject commit a16ff700c1d54582711d14c2f57341bb9bd82be2

From fa8e19c1ce3d95c1f2a65e40b4c164da348d7557 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Sat, 23 Mar 2024 19:01:05 -0700
Subject: [PATCH 3/3] Overhaul done

---
 src/configuration.rs                   |  31 +++-
 src/main.rs                            |   4 -
 src/memory_backends/mod.rs             |   7 +
 src/transformer_backends/anthropic.rs  | 207 +++++++++++++++++++++++++
 src/transformer_backends/mod.rs        |   4 +
 src/transformer_backends/openai/mod.rs | 128 ++++++++------
 6 files changed, 316 insertions(+), 65 deletions(-)
 create mode 100644 src/transformer_backends/anthropic.rs

diff --git a/src/configuration.rs b/src/configuration.rs
index c51f0bf..39f076b 100644
--- a/src/configuration.rs
+++ b/src/configuration.rs
@@ -19,6 +19,7 @@ pub enum ValidMemoryBackend {
 pub enum ValidTransformerBackend {
     LlamaCPP(ModelGGUF),
     OpenAI(OpenAI),
+    Anthropic(Anthropic),
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -36,6 +37,7 @@ pub struct Chat {
 }
 
 #[derive(Clone, Debug, Deserialize)]
+#[allow(clippy::upper_case_acronyms)]
 pub struct FIM {
     pub start: String,
     pub middle: String,
@@ -145,7 +147,7 @@ const fn openai_max_context() -> usize {
     DEFAULT_OPENAI_MAX_CONTEXT
 }
 
-#[derive(Clone, Debug, Default, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub struct OpenAI {
     // The auth token env var name
     pub auth_token_env_var_name: Option<String>,
@@ -176,9 +178,35 @@ pub struct OpenAI {
     max_context: usize,
 }
 
+#[derive(Clone, Debug, Deserialize)]
+pub struct Anthropic {
+    // The auth token env var name
+    pub auth_token_env_var_name: Option<String>,
+    pub auth_token: Option<String>,
+    // The completions endpoint
+    pub completions_endpoint: Option<String>,
+    // The chat endpoint
+    pub chat_endpoint: Option<String>,
+    // The model name
+    pub model: String,
+    // Fill in the middle support
+    pub fim: Option<FIM>,
+    // The maximum number of new tokens to generate
+    #[serde(default)]
+    pub max_tokens: MaxTokens,
+    // Chat args
+    pub chat: Chat,
+    // System prompt
+    #[serde(default = "openai_top_p_default")]
+    pub top_p: f32,
+    #[serde(default = "openai_temperature")]
+    pub temperature: f32,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 struct ValidTransformerConfiguration {
     openai: Option<OpenAI>,
+    anthropic: Option<Anthropic>,
     model_gguf: Option<ModelGGUF>,
 }
 
@@ -186,6 +214,7 @@ impl Default for ValidTransformerConfiguration {
     fn default() -> Self {
         Self {
             model_gguf: Some(ModelGGUF::default()),
+            anthropic: None,
            openai: None,
         }
     }
diff --git a/src/main.rs b/src/main.rs
index 2569aea..c32a324 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -72,10 +72,6 @@ fn main() -> Result<()> {
     Ok(())
 }
 
-// This main loop is tricky
-// We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get
-// Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request
-// Note that we also want to have the memory backend in the worker thread as that may also involve heavy computations
 fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     // Build our configuration
     let configuration = Configuration::new(args)?;
diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs
index 58e39f1..1a0832d 100644
--- a/src/memory_backends/mod.rs
+++ b/src/memory_backends/mod.rs
@@ -73,4 +73,11 @@ impl Prompt {
             code: r#"def test_code():\n    <CURSOR>"#.to_string(),
         }
     }
+
+    pub fn default_without_cursor() -> Self {
+        Self {
+            context: r#"def test_context():\n    pass"#.to_string(),
+            code: r#"def test_code():\n    "#.to_string(),
+        }
+    }
 }
diff --git a/src/transformer_backends/anthropic.rs b/src/transformer_backends/anthropic.rs
new file mode 100644
index 0000000..9d2bb9c
--- /dev/null
+++ b/src/transformer_backends/anthropic.rs
@@ -0,0 +1,207 @@
+use anyhow::Context;
+use serde::Deserialize;
+use serde_json::{json, Value};
+use tracing::instrument;
+
+use crate::{
+    configuration::{self, ChatMessage},
+    memory_backends::Prompt,
+    transformer_worker::{
+        DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
+    },
+    utils::format_chat_messages,
+};
+
+use super::TransformerBackend;
+
+pub struct Anthropic {
+    configuration: configuration::Anthropic,
+}
+
+#[derive(Deserialize)]
+struct AnthropicChatMessage {
+    text: String,
+}
+
+#[derive(Deserialize)]
+struct AnthropicChatResponse {
+    content: Option<Vec<AnthropicChatMessage>>,
+    error: Option<Value>,
+}
+
+impl Anthropic {
+    #[instrument]
+    pub fn new(configuration: configuration::Anthropic) -> Self {
+        Self { configuration }
+    }
+
+    async fn get_chat(
+        &self,
+        system_prompt: String,
+        messages: Vec<ChatMessage>,
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
+        eprintln!(
+            "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
+            messages
+        );
+        let client = reqwest::Client::new();
+        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
+            std::env::var(env_var_name)?
+        } else if let Some(token) = &self.configuration.auth_token {
+            token.to_string()
+        } else {
+            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `anthropic` to use the Anthropic API");
+        };
+        let res: AnthropicChatResponse = client
+            .post(
+                self.configuration
+                    .chat_endpoint
+                    .as_ref()
+                    .context("must specify `chat_endpoint` to use chat")?,
+            )
+            .header("x-api-key", token)
+            .header("anthropic-version", "2023-06-01")
+            .header("Content-Type", "application/json")
+            .header("Accept", "application/json")
+            .json(&json!({
+                "model": self.configuration.model,
+                "system": system_prompt,
+                "max_tokens": max_tokens,
+                "top_p": self.configuration.top_p,
+                "temperature": self.configuration.temperature,
+                "messages": messages
+            }))
+            .send()
+            .await?
+            .json()
+            .await?;
+        if let Some(error) = res.error {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(mut content) = res.content {
+            Ok(std::mem::take(&mut content[0].text))
+        } else {
+            anyhow::bail!("Unknown error while making request to Anthropic")
+        }
+    }
+
+    async fn do_get_chat(
+        &self,
+        prompt: &Prompt,
+        messages: &[ChatMessage],
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
+        let mut messages = format_chat_messages(messages, prompt);
+        if messages[0].role != "system" {
+            anyhow::bail!(
+                "When using Anthropic, the first message in chat must have role = `system`"
+            )
+        }
+        let system_prompt = messages.remove(0).content;
+        self.get_chat(system_prompt, messages, max_tokens).await
+    }
+}
+
+#[async_trait::async_trait]
+impl TransformerBackend for Anthropic {
+    #[instrument(skip(self))]
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+        eprintln!("--------------{:?}---------------", prompt);
+        let max_tokens = self.configuration.max_tokens.completion;
+        let insert_text = match &self.configuration.chat.completion {
+            Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
+            None => {
+                anyhow::bail!("Please provide `anthropic->chat->completion` messages")
+            }
+        };
+        Ok(DoCompletionResponse { insert_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+        eprintln!("--------------{:?}---------------", prompt);
+        let max_tokens = self.configuration.max_tokens.generation;
+        let generated_text = match &self.configuration.chat.generation {
+            Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
+            None => {
+                anyhow::bail!("Please provide `anthropic->chat->generation` messages")
+            }
+        };
+        Ok(DoGenerateResponse { generated_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate_stream(
+        &self,
+        request: &GenerateStreamRequest,
+    ) -> anyhow::Result<DoGenerateStreamResponse> {
+        unimplemented!()
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[tokio::test]
+    async fn anthropic_chat_do_completion() -> anyhow::Result<()> {
+        let configuration: configuration::Anthropic = serde_json::from_value(json!({
+            "chat_endpoint": "https://api.anthropic.com/v1/messages",
+            "model": "claude-3-haiku-20240307",
+            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
+            "chat": {
+                "completion": [
+                    {
+                        "role": "system",
+                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
+                    },
+                    {
+                        "role": "user",
+                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                    }
+                ],
+            },
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let anthropic = Anthropic::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = anthropic.do_completion(&prompt).await?;
+        assert!(!response.insert_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn anthropic_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: configuration::Anthropic = serde_json::from_value(json!({
+            "chat_endpoint": "https://api.anthropic.com/v1/messages",
+            "model": "claude-3-haiku-20240307",
+            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
+            "chat": {
+                "generation": [
+                    {
+                        "role": "system",
+                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
                    },
+                    {
+                        "role": "user",
+                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                    }
+                ]
+            },
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let anthropic = Anthropic::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = anthropic.do_generate(&prompt).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+}
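The chat templates used by the new Anthropic backend above (and by the OpenAI backend below) carry {CONTEXT} and {CODE} placeholders that are filled from the current Prompt before the request is sent; that substitution happens in format_chat_messages, which lives in src/utils.rs and is not part of this diff. Below is a small, self-contained sketch of that substitution under the assumption that it is plain string replacement; the Message type and render_messages function are stand-ins invented for the example, not the crate's ChatMessage or helpers.

// Stand-in for the crate's format_chat_messages helper (src/utils.rs, not shown in this
// patch). Assumes the {CONTEXT} and {CODE} placeholders seen in the templates above are
// filled by plain string replacement.
#[derive(Clone, Debug)]
struct Message {
    role: String,
    content: String,
}

fn render_messages(templates: &[Message], context: &str, code: &str) -> Vec<Message> {
    templates
        .iter()
        .map(|m| Message {
            role: m.role.clone(),
            content: m
                .content
                .replace("{CONTEXT}", context)
                .replace("{CODE}", code),
        })
        .collect()
}

fn main() {
    let templates = vec![Message {
        role: "user".to_string(),
        content: "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}".to_string(),
    }];
    let rendered = render_messages(&templates, "def helper():\n    pass", "def main():\n    <CURSOR>");
    println!("{}", rendered[0].content);
}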
diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs
index 4883cfb..bbf5386 100644
--- a/src/transformer_backends/mod.rs
+++ b/src/transformer_backends/mod.rs
@@ -6,6 +6,7 @@ use crate::{
     },
 };
 
+mod anthropic;
 mod llama_cpp;
 mod openai;
 
@@ -30,6 +31,9 @@ impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send + Sync> {
             ValidTransformerBackend::OpenAI(openai_config) => {
                 Ok(Box::new(openai::OpenAI::new(openai_config)))
             }
+            ValidTransformerBackend::Anthropic(anthropic_config) => {
+                Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
+            }
         }
     }
 }
diff --git a/src/transformer_backends/openai/mod.rs b/src/transformer_backends/openai/mod.rs
index 5fcc2d6..29c75ee 100644
--- a/src/transformer_backends/openai/mod.rs
+++ b/src/transformer_backends/openai/mod.rs
@@ -49,22 +49,26 @@ impl OpenAI {
         Self { configuration }
     }
 
-    fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
-        eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
-        let client = reqwest::blocking::Client::new();
-        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
-            std::env::var(env_var_name)?
+    fn get_token(&self) -> anyhow::Result<String> {
+        if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
+            Ok(std::env::var(env_var_name)?)
         } else if let Some(token) = &self.configuration.auth_token {
-            token.to_string()
+            Ok(token.to_string())
         } else {
-            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
-        };
+            anyhow::bail!("set `auth_token_env_var_name` or `auth_token` in `transformer->openai` to use an OpenAI compatible API")
+        }
+    }
+
+    async fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
+        eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
+        let client = reqwest::Client::new();
+        let token = self.get_token()?;
         let res: OpenAICompletionsResponse = client
             .post(
                 self.configuration
                     .completions_endpoint
                     .as_ref()
-                    .context("must specify `completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `chat_endpoint` and `chat` messages.")?,
+                    .context("specify `transformer->openai->completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `transformer->openai->chat_endpoint` and `transformer->openai->chat` messages.")?,
             )
             .bearer_auth(token)
             .header("Content-Type", "application/json")
@@ -80,30 +84,28 @@ impl OpenAI {
                 "echo": false,
                 "prompt": prompt
             }))
-            .send()?
-            .json()?;
+            .send().await?
+            .json().await?;
         if let Some(error) = res.error {
             anyhow::bail!("{:?}", error.to_string())
-        } else if let Some(choices) = res.choices {
-            Ok(choices[0].text.clone())
+        } else if let Some(mut choices) = res.choices {
+            Ok(std::mem::take(&mut choices[0].text))
         } else {
             anyhow::bail!("Unknown error while making request to OpenAI")
         }
     }
 
-    fn get_chat(&self, messages: Vec<ChatMessage>, max_tokens: usize) -> anyhow::Result<String> {
+    async fn get_chat(
+        &self,
+        messages: Vec<ChatMessage>,
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
         eprintln!(
             "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
             messages
         );
-        let client = reqwest::blocking::Client::new();
-        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
-            std::env::var(env_var_name)?
-        } else if let Some(token) = &self.configuration.auth_token {
-            token.to_string()
-        } else {
-            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
-        };
+        let client = reqwest::Client::new();
+        let token = self.get_token()?;
         let res: OpenAIChatResponse = client
             .post(
                 self.configuration
@@ -124,8 +126,10 @@ impl OpenAI {
                 "temperature": self.configuration.temperature,
                 "messages": messages
             }))
-            .send()?
-            .json()?;
+            .send()
+            .await?
+            .json()
+            .await?;
         if let Some(error) = res.error {
             anyhow::bail!("{:?}", error.to_string())
         } else if let Some(choices) = res.choices {
@@ -134,6 +138,27 @@ impl OpenAI {
             anyhow::bail!("Unknown error while making request to OpenAI")
         }
     }
+
+    async fn do_chat_completion(
+        &self,
+        prompt: &Prompt,
+        messages: Option<&Vec<ChatMessage>>,
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
+        match messages {
+            Some(completion_messages) => {
+                let messages = format_chat_messages(completion_messages, prompt);
+                self.get_chat(messages, max_tokens).await
+            }
+            None => {
+                self.get_completion(
+                    &format_context_code(&prompt.context, &prompt.code),
+                    max_tokens,
+                )
+                .await
+            }
+        }
+    }
 }
 
 #[async_trait::async_trait]
@@ -142,22 +167,14 @@ impl TransformerBackend for OpenAI {
     async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
         eprintln!("--------------{:?}---------------", prompt);
         let max_tokens = self.configuration.max_tokens.completion;
-        let insert_text = match &self.configuration.chat {
-            Some(c) => match &c.completion {
-                Some(completion_messages) => {
-                    let messages = format_chat_messages(completion_messages, prompt);
-                    self.get_chat(messages, max_tokens)?
-                }
-                None => self.get_completion(
-                    &format_context_code(&prompt.context, &prompt.code),
-                    max_tokens,
-                )?,
-            },
-            None => self.get_completion(
-                &format_context_code(&prompt.context, &prompt.code),
-                max_tokens,
-            )?,
-        };
+        let messages = self
+            .configuration
+            .chat
+            .as_ref()
+            .and_then(|c| c.completion.as_ref());
+        let insert_text = self
+            .do_chat_completion(prompt, messages, max_tokens)
+            .await?;
         Ok(DoCompletionResponse { insert_text })
     }
 
@@ -165,22 +182,14 @@ impl TransformerBackend for OpenAI {
     async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
         eprintln!("--------------{:?}---------------", prompt);
         let max_tokens = self.configuration.max_tokens.generation;
-        let generated_text = match &self.configuration.chat {
-            Some(c) => match &c.generation {
-                Some(completion_messages) => {
-                    let messages = format_chat_messages(completion_messages, prompt);
-                    self.get_chat(messages, max_tokens)?
- } - None => self.get_completion( - &format_context_code(&prompt.context, &prompt.code), - max_tokens, - )?, - }, - None => self.get_completion( - &format_context_code(&prompt.context, &prompt.code), - max_tokens, - )?, - }; + let messages = self + .configuration + .chat + .as_ref() + .and_then(|c| c.generation.as_ref()); + let generated_text = self + .do_chat_completion(prompt, messages, max_tokens) + .await?; Ok(DoGenerateResponse { generated_text }) } @@ -210,7 +219,7 @@ mod test { "max_context": 4096 }))?; let openai = OpenAI::new(configuration); - let prompt = Prompt::default_with_cursor(); + let prompt = Prompt::default_without_cursor(); let response = openai.do_completion(&prompt).await?; assert!(!response.insert_text.is_empty()); Ok(()) @@ -260,7 +269,7 @@ mod test { "max_context": 4096 }))?; let openai = OpenAI::new(configuration); - let prompt = Prompt::default_with_cursor(); + let prompt = Prompt::default_without_cursor(); let response = openai.do_generate(&prompt).await?; assert!(!response.generated_text.is_empty()); Ok(()) @@ -269,7 +278,6 @@ mod test { #[tokio::test] async fn openai_chat_do_generate() -> anyhow::Result<()> { let configuration: configuration::OpenAI = serde_json::from_value(json!({ - "config": { "chat_endpoint": "https://api.openai.com/v1/chat/completions", "model": "gpt-3.5-turbo", "auth_token_env_var_name": "OPENAI_API_KEY", @@ -290,7 +298,7 @@ mod test { "generation": 64 }, "max_context": 4096 - }}))?; + }))?; let openai = OpenAI::new(configuration); let prompt = Prompt::default_with_cursor(); let response = openai.do_generate(&prompt).await?;