diff --git a/Cargo.lock b/Cargo.lock index e1db173..0cfba1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -110,9 +110,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.80" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" +checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" [[package]] name = "assert_cmd" @@ -131,9 +131,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85" dependencies = [ "proc-macro2", "quote", @@ -149,16 +149,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "atomic-write-file" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8204db279bf648d64fe845bd8840f78b39c8132ed4d6a4194c3b10d4b4cfb0b" -dependencies = [ - "nix", - "rand", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -532,6 +522,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctrlc" +version = "3.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +dependencies = [ + "nix", + "windows-sys 0.52.0", +] + [[package]] name = "darling" version = "0.14.4" @@ -1410,6 +1410,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "llama-cpp-2" version = "0.1.34" +source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064" dependencies = [ "llama-cpp-sys-2", "thiserror", @@ -1419,6 +1420,7 @@ dependencies = [ [[package]] name = "llama-cpp-sys-2" version = "0.1.34" +source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064" dependencies = [ "bindgen", "cc", @@ -1465,6 +1467,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assert_cmd", + "async-trait", "directories", "hf-hub", "ignore", @@ -1956,6 +1959,7 @@ dependencies = [ "chrono", "clap", "colored", + "ctrlc", "futures", "indicatif", "inquire", @@ -2075,9 +2079,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] @@ -2224,9 +2228,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.25" +version = "0.11.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946" +checksum = "78bf93c4af7a8bb7d879d51cebe797356ff10ae8516ace542b5182d9dcac10b2" dependencies = [ "base64 0.21.7", "bytes", @@ -2769,9 +2773,9 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" +checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa" dependencies = [ "sqlx-core", "sqlx-macros", 
@@ -2782,9 +2786,9 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" +checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6" dependencies = [ "ahash", "atoi", @@ -2792,7 +2796,6 @@ dependencies = [ "bytes", "crc", "crossbeam-queue", - "dotenvy", "either", "event-listener", "futures-channel", @@ -2827,9 +2830,9 @@ dependencies = [ [[package]] name = "sqlx-macros" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" +checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127" dependencies = [ "proc-macro2", "quote", @@ -2840,11 +2843,10 @@ dependencies = [ [[package]] name = "sqlx-macros-core" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841" +checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8" dependencies = [ - "atomic-write-file", "dotenvy", "either", "heck", @@ -2867,9 +2869,9 @@ dependencies = [ [[package]] name = "sqlx-mysql" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" +checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418" dependencies = [ "atoi", "base64 0.21.7", @@ -2911,9 +2913,9 @@ dependencies = [ [[package]] name = "sqlx-postgres" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" +checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e" dependencies = [ "atoi", "base64 0.21.7", @@ -2938,7 +2940,6 @@ dependencies = [ "rand", "serde", "serde_json", - "sha1", "sha2", "smallvec", "sqlx-core", @@ -2952,9 +2953,9 @@ dependencies = [ [[package]] name = "sqlx-sqlite" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" +checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa" dependencies = [ "atoi", "flume", @@ -3051,20 +3052,20 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "system-configuration" -version = "0.6.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ - "bitflags 2.4.2", + "bitflags 1.3.2", "core-foundation", "system-configuration-sys", ] [[package]] name = "system-configuration-sys" -version = "0.6.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" dependencies = [ "core-foundation-sys", "libc", @@ -3090,18 +3091,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "thiserror" -version = "1.0.57" +version = "1.0.58" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", @@ -3631,9 +3632,9 @@ dependencies = [ [[package]] name = "whoami" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fec781d48b41f8163426ed18e8fc2864c12937df9ce54c88ede7bd47270893e" +checksum = "a44ab49fad634e88f55bf8f9bb3abd2f27d7204172a112c7c9987e01c1c94ea9" dependencies = [ "redox_syscall", "wasite", diff --git a/Cargo.toml b/Cargo.toml index eb42568..13e8280 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ parking_lot = "0.12.1" once_cell = "1.19.0" directories = "5.0.1" # llama-cpp-2 = "0.1.31" -llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" } +llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" } minijinja = { version = "1.0.12", features = ["loader"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tracing = "0.1.40" @@ -30,6 +30,7 @@ ignore = "0.4.22" pgml = { path = "submodules/postgresml/pgml-sdks/pgml" } tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] } indexmap = "2.2.5" +async-trait = "0.1.78" [features] default = [] diff --git a/src/configuration.rs b/src/configuration.rs index d473385..c51f0bf 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -129,10 +129,6 @@ const fn openai_top_p_default() -> f32 { 0.95 } -const fn openai_top_k_default() -> usize { - 40 -} - const fn openai_presence_penalty() -> f32 { 0. 
} @@ -155,7 +151,9 @@ pub struct OpenAI { pub auth_token_env_var_name: Option, pub auth_token: Option, // The completions endpoint - pub completions_endpoint: String, + pub completions_endpoint: Option, + // The chat endpoint + pub chat_endpoint: Option, // The model name pub model: String, // Fill in the middle support @@ -168,8 +166,6 @@ pub struct OpenAI { // Other available args #[serde(default = "openai_top_p_default")] pub top_p: f32, - #[serde(default = "openai_top_k_default")] - pub top_k: usize, #[serde(default = "openai_presence_penalty")] pub presence_penalty: f32, #[serde(default = "openai_frequency_penalty")] diff --git a/src/main.rs b/src/main.rs index 9246607..2d3b5af 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,7 @@ fn main() -> Result<()> { .init(); let (connection, io_threads) = Connection::stdio(); - let server_capabilities = serde_json::to_value(&ServerCapabilities { + let server_capabilities = serde_json::to_value(ServerCapabilities { completion_provider: Some(CompletionOptions::default()), text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind( TextDocumentSyncKind::INCREMENTAL, @@ -77,7 +77,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { let transformer_backend: Box = args.clone().try_into()?; // Set the memory_backend - let memory_backend: Arc>> = + let memory_backend: Box = Arc::new(Mutex::new(args.clone().try_into()?)); // Wrap the connection for sharing between threads @@ -87,6 +87,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { let last_worker_request = Arc::new(Mutex::new(None)); // Thread local variables + // TODO: Setup some kind of handler for errors here let thread_memory_backend = memory_backend.clone(); let thread_last_worker_request = last_worker_request.clone(); let thread_connection = connection.clone(); @@ -97,7 +98,8 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { thread_last_worker_request, thread_connection, ) - .run(); + .run() + .unwrap(); }); for msg in &connection.receiver { @@ -143,13 +145,13 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { Message::Notification(not) => { if notification_is::(¬) { let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?; - memory_backend.lock().opened_text_document(params)?; + // memory_backend.lock().opened_text_document(params)?; } else if notification_is::(¬) { let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?; - memory_backend.lock().changed_text_document(params)?; + // memory_backend.lock().changed_text_document(params)?; } else if notification_is::(¬) { let params: RenameFilesParams = serde_json::from_value(not.params)?; - memory_backend.lock().renamed_file(params)?; + // memory_backend.lock().renamed_file(params)?; } } _ => (), @@ -170,18 +172,18 @@ mod tests { ////////////////////////////////////// ////////////////////////////////////// - #[test] - fn completion_with_default_arguments() { + #[tokio::test] + async fn completion_with_default_arguments() { let args = json!({}); let configuration = Configuration::new(args).unwrap(); let backend: Box = configuration.clone().try_into().unwrap(); let prompt = Prompt::new("".to_string(), "def fibn".to_string()); - let response = backend.do_completion(&prompt).unwrap(); + let response = backend.do_completion(&prompt).await.unwrap(); assert!(!response.insert_text.is_empty()) } - #[test] - fn completion_with_custom_gguf_model() { + #[tokio::test] + async fn 
completion_with_custom_gguf_model() { let args = json!({ "initializationOptions": { "memory": { @@ -232,7 +234,7 @@ mod tests { let configuration = Configuration::new(args).unwrap(); let backend: Box = configuration.clone().try_into().unwrap(); let prompt = Prompt::new("".to_string(), "def fibn".to_string()); - let response = backend.do_completion(&prompt).unwrap(); + let response = backend.do_completion(&prompt).await.unwrap(); assert!(!response.insert_text.is_empty()); } } diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs index 15c4619..e8d832e 100644 --- a/src/memory_backends/file_store.rs +++ b/src/memory_backends/file_store.rs @@ -71,7 +71,7 @@ impl FileStore { .iter() .filter(|f| **f != current_document_uri) { - let needed = characters.checked_sub(rope.len_chars()).unwrap_or(0); + let needed = characters.saturating_sub(rope.len_chars()); if needed == 0 { break; } @@ -99,7 +99,7 @@ impl FileStore { .clone(); let cursor_index = rope.line_to_char(position.position.line as usize) + position.position.character as usize; - let start = cursor_index.checked_sub(characters / 2).unwrap_or(0); + let start = cursor_index.saturating_sub(characters / 2); let end = rope .len_chars() .min(cursor_index + (characters - (cursor_index - start))); @@ -137,15 +137,15 @@ impl FileStore { if is_chat_enabled || rope.len_chars() != cursor_index => { let max_length = tokens_to_estimated_characters(max_context_length); - let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0); + let start = cursor_index.saturating_sub(max_length / 2); let end = rope .len_chars() .min(cursor_index + (max_length - (cursor_index - start))); if is_chat_enabled { - rope.insert(cursor_index, "{CURSOR}"); + rope.insert(cursor_index, ""); let rope_slice = rope - .get_slice(start..end + "{CURSOR}".chars().count()) + .get_slice(start..end + "".chars().count()) .context("Error getting rope slice")?; rope_slice.to_string() } else { @@ -166,9 +166,8 @@ impl FileStore { } } _ => { - let start = cursor_index - .checked_sub(tokens_to_estimated_characters(max_context_length)) - .unwrap_or(0); + let start = + cursor_index.saturating_sub(tokens_to_estimated_characters(max_context_length)); let rope_slice = rope .get_slice(start..cursor_index) .context("Error getting rope slice")?; @@ -178,9 +177,13 @@ impl FileStore { } } +#[async_trait::async_trait] impl MemoryBackend for FileStore { #[instrument(skip(self))] - fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result { + async fn get_filter_text( + &self, + position: &TextDocumentPositionParams, + ) -> anyhow::Result { let rope = self .file_map .get(position.text_document.uri.as_str()) @@ -193,7 +196,7 @@ impl MemoryBackend for FileStore { } #[instrument(skip(self))] - fn build_prompt( + async fn build_prompt( &mut self, position: &TextDocumentPositionParams, prompt_for_type: PromptForType, @@ -207,7 +210,7 @@ impl MemoryBackend for FileStore { } #[instrument(skip(self))] - fn opened_text_document( + async fn opened_text_document( &mut self, params: lsp_types::DidOpenTextDocumentParams, ) -> anyhow::Result<()> { @@ -219,7 +222,7 @@ impl MemoryBackend for FileStore { } #[instrument(skip(self))] - fn changed_text_document( + async fn changed_text_document( &mut self, params: lsp_types::DidChangeTextDocumentParams, ) -> anyhow::Result<()> { @@ -246,7 +249,7 @@ impl MemoryBackend for FileStore { } #[instrument(skip(self))] - fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { + async fn 
renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { for file_rename in params.files { if let Some(rope) = self.file_map.remove(&file_rename.old_uri) { self.file_map.insert(file_rename.new_uri, rope); diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs index 0931318..cb8c6b6 100644 --- a/src/memory_backends/mod.rs +++ b/src/memory_backends/mod.rs @@ -26,19 +26,29 @@ pub enum PromptForType { Generate, } +#[async_trait::async_trait] pub trait MemoryBackend { - fn init(&self) -> anyhow::Result<()> { + async fn init(&self) -> anyhow::Result<()> { Ok(()) } - fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>; - fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>; - fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>; - fn build_prompt( + async fn opened_text_document( + &mut self, + params: DidOpenTextDocumentParams, + ) -> anyhow::Result<()>; + async fn changed_text_document( + &mut self, + params: DidChangeTextDocumentParams, + ) -> anyhow::Result<()>; + async fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>; + async fn build_prompt( &mut self, position: &TextDocumentPositionParams, prompt_for_type: PromptForType, ) -> anyhow::Result; - fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result; + async fn get_filter_text( + &self, + position: &TextDocumentPositionParams, + ) -> anyhow::Result; } impl TryFrom for Box { @@ -55,3 +65,15 @@ impl TryFrom for Box { } } } + +// This makes testing much easier. Every transformer backend takes in a prompt. When verifying they work, its +// easier to just pass in a default prompt. +#[cfg(test)] +impl Prompt { + pub fn default_with_cursor() -> Self { + Self { + context: r#"def test_context():\n pass"#.to_string(), + code: r#"def test_code():\n "#.to_string(), + } + } +} diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs index c7fce9e..da543e0 100644 --- a/src/memory_backends/postgresml/mod.rs +++ b/src/memory_backends/postgresml/mod.rs @@ -40,7 +40,7 @@ impl PostgresML { }; // TODO: Think on the naming of the collection // Maybe filter on metadata or I'm not sure - let collection = Collection::new("test-lsp-ai-2", Some(database_url))?; + let collection = Collection::new("test-lsp-ai-3", Some(database_url))?; // TODO: Review the pipeline let pipeline = Pipeline::new( "v1", @@ -50,7 +50,7 @@ impl PostgresML { "splitter": { "model": "recursive_character", "parameters": { - "chunk_size": 512, + "chunk_size": 1500, "chunk_overlap": 40 } }, @@ -90,7 +90,7 @@ impl PostgresML { .into_iter() .map(|path| { let text = std::fs::read_to_string(&path) - .expect(format!("Error reading path: {}", path).as_str()); + .unwrap_or_else(|_| panic!("Error reading path: {}", path)); json!({ "id": path, "text": text @@ -121,24 +121,28 @@ impl PostgresML { } } +#[async_trait::async_trait] impl MemoryBackend for PostgresML { #[instrument(skip(self))] - fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result { - self.file_store.get_filter_text(position) + async fn get_filter_text( + &self, + position: &TextDocumentPositionParams, + ) -> anyhow::Result { + self.file_store.get_filter_text(position).await } #[instrument(skip(self))] - fn build_prompt( + async fn build_prompt( &mut self, position: &TextDocumentPositionParams, prompt_for_type: PromptForType, ) -> anyhow::Result { - // This is blocking, but that is 
ok as we only query for it from the worker when we are actually doing a transform let query = self .file_store .get_characters_around_position(position, 512)?; - let res = self.runtime.block_on( - self.collection.vector_search( + let res = self + .collection + .vector_search( json!({ "query": { "fields": { @@ -151,8 +155,8 @@ impl MemoryBackend for PostgresML { }) .into(), &mut self.pipeline, - ), - )?; + ) + .await?; let context = res .into_iter() .map(|c| { @@ -176,7 +180,7 @@ impl MemoryBackend for PostgresML { } #[instrument(skip(self))] - fn opened_text_document( + async fn opened_text_document( &mut self, params: lsp_types::DidOpenTextDocumentParams, ) -> anyhow::Result<()> { @@ -185,68 +189,63 @@ impl MemoryBackend for PostgresML { let task_added_pipeline = self.added_pipeline; let mut task_collection = self.collection.clone(); let mut task_pipeline = self.pipeline.clone(); - self.runtime.spawn(async move { - if !task_added_pipeline { - task_collection - .add_pipeline(&mut task_pipeline) - .await - .expect("PGML - Error adding pipeline to collection"); - } + if !task_added_pipeline { + task_collection + .add_pipeline(&mut task_pipeline) + .await + .expect("PGML - Error adding pipeline to collection"); + } + task_collection + .upsert_documents( + vec![json!({ + "id": path, + "text": text + }) + .into()], + None, + ) + .await + .expect("PGML - Error upserting documents"); + self.file_store.opened_text_document(params).await + } + + #[instrument(skip(self))] + async fn changed_text_document( + &mut self, + params: lsp_types::DidChangeTextDocumentParams, + ) -> anyhow::Result<()> { + let path = params.text_document.uri.path().to_owned(); + self.debounce_tx.send(path)?; + self.file_store.changed_text_document(params).await + } + + #[instrument(skip(self))] + async fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { + let mut task_collection = self.collection.clone(); + let task_params = params.clone(); + for file in task_params.files { + task_collection + .delete_documents( + json!({ + "id": file.old_uri + }) + .into(), + ) + .await + .expect("PGML - Error deleting file"); + let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file"); task_collection .upsert_documents( vec![json!({ - "id": path, + "id": file.new_uri, "text": text }) .into()], None, ) .await - .expect("PGML - Error upserting documents"); - }); - self.file_store.opened_text_document(params) - } - - #[instrument(skip(self))] - fn changed_text_document( - &mut self, - params: lsp_types::DidChangeTextDocumentParams, - ) -> anyhow::Result<()> { - let path = params.text_document.uri.path().to_owned(); - self.debounce_tx.send(path)?; - self.file_store.changed_text_document(params) - } - - #[instrument(skip(self))] - fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { - let mut task_collection = self.collection.clone(); - let task_params = params.clone(); - self.runtime.spawn(async move { - for file in task_params.files { - task_collection - .delete_documents( - json!({ - "id": file.old_uri - }) - .into(), - ) - .await - .expect("PGML - Error deleting file"); - let text = - std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file"); - task_collection - .upsert_documents( - vec![json!({ - "id": file.new_uri, - "text": text - }) - .into()], - None, - ) - .await - .expect("PGML - Error adding pipeline to collection"); - } - }); - self.file_store.renamed_file(params) + .expect("PGML - Error adding pipeline to 
collection"); + } + self.file_store.renamed_file(params).await } } diff --git a/src/transformer_backends/llama_cpp/mod.rs b/src/transformer_backends/llama_cpp/mod.rs index ed012a6..8e865ce 100644 --- a/src/transformer_backends/llama_cpp/mod.rs +++ b/src/transformer_backends/llama_cpp/mod.rs @@ -2,7 +2,6 @@ use anyhow::Context; use hf_hub::api::sync::ApiBuilder; use tracing::{debug, instrument}; -use super::TransformerBackend; use crate::{ configuration::{self}, memory_backends::Prompt, @@ -16,6 +15,8 @@ use crate::{ mod model; use model::Model; +use super::TransformerBackend; + pub struct LlamaCPP { model: Model, configuration: configuration::ModelGGUF, @@ -62,9 +63,10 @@ impl LlamaCPP { } } +#[async_trait::async_trait] impl TransformerBackend for LlamaCPP { #[instrument(skip(self))] - fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> { + async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> { // let prompt = self.get_prompt_string(prompt)?; let prompt = &prompt.code; debug!("Prompt string for LLM: {}", prompt); @@ -75,7 +77,7 @@ impl TransformerBackend for LlamaCPP { } #[instrument(skip(self))] - fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> { + async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> { // let prompt = self.get_prompt_string(prompt)?; // debug!("Prompt string for LLM: {}", prompt); let prompt = &prompt.code; @@ -86,7 +88,7 @@ impl TransformerBackend for LlamaCPP { } #[instrument(skip(self))] - fn do_generate_stream( + async fn do_generate_stream( &self, _request: &GenerateStreamRequest, ) -> anyhow::Result<DoGenerateStreamResponse> { diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs index 38eea40..8b02357 100644 --- a/src/transformer_backends/mod.rs +++ b/src/transformer_backends/mod.rs @@ -9,11 +9,11 @@ use crate::{ mod llama_cpp; mod openai; +#[async_trait::async_trait] pub trait TransformerBackend { - // Should all take an enum of chat messages or just a string for completion - fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>; - fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>; - fn do_generate_stream( + async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>; + async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>; + async fn do_generate_stream( &self, request: &GenerateStreamRequest, ) -> anyhow::Result<DoGenerateStreamResponse>; diff --git a/src/transformer_backends/openai/mod.rs b/src/transformer_backends/openai/mod.rs index 2ff53c0..dfbf5bf 100644 --- a/src/transformer_backends/openai/mod.rs +++ b/src/transformer_backends/openai/mod.rs @@ -1,16 +1,22 @@ +// A TransformerBackend implementation for OpenAI compatible completion and chat APIs +// NOTE: When decoding responses from OpenAI compatible services, we don't care about every field + +use anyhow::Context; use serde::Deserialize; -use serde_json::json; +use serde_json::{json, Value}; use tracing::instrument; -use super::TransformerBackend; use crate::{ - configuration, + configuration::{self, ChatMessage}, memory_backends::Prompt, + utils::{format_chat_messages, format_context_code}, worker::{ DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest, }, }; +use super::TransformerBackend; + pub struct OpenAI { configuration: configuration::OpenAI, } @@ -22,7 +28,19 @@ struct OpenAICompletionsChoice { #[derive(Deserialize)] struct OpenAICompletionsResponse { - choices: Vec<OpenAICompletionsChoice>, + choices: Option<Vec<OpenAICompletionsChoice>>, + error: Option<Value>, +} + +#[derive(Deserialize)] +struct OpenAIChatChoices { + message: ChatMessage, +} + +#[derive(Deserialize)] +struct OpenAIChatResponse { + choices: Option<Vec<OpenAIChatChoices>>, + error:
Option<Value>, } impl OpenAI { @@ -42,7 +60,12 @@ impl OpenAI { anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API"); }; let res: OpenAICompletionsResponse = client - .post(&self.configuration.completions_endpoint) + .post( + self.configuration + .completions_endpoint + .as_ref() + .context("must specify `completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `chat_endpoint` and `chat` messages.")?, + ) .bearer_auth(token) .header("Content-Type", "application/json") .header("Accept", "application/json") @@ -51,7 +74,6 @@ impl OpenAI { "max_tokens": max_tokens, "n": 1, "top_p": self.configuration.top_p, - "top_k": self.configuration.top_k, "presence_penalty": self.configuration.presence_penalty, "frequency_penalty": self.configuration.frequency_penalty, "temperature": self.configuration.temperature, @@ -60,34 +82,219 @@ impl OpenAI { })) .send()? .json()?; - eprintln!("**********RECEIVED REQUEST********"); - Ok(res.choices[0].text.clone()) + if let Some(error) = res.error { + anyhow::bail!("{:?}", error.to_string()) + } else if let Some(choices) = res.choices { + Ok(choices[0].text.clone()) + } else { + anyhow::bail!("Unknown error while making request to OpenAI") + } + } + + fn get_chat(&self, messages: Vec<ChatMessage>, max_tokens: usize) -> anyhow::Result<String> { + eprintln!( + "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******", + messages + ); + let client = reqwest::blocking::Client::new(); + let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name { + std::env::var(env_var_name)? + } else if let Some(token) = &self.configuration.auth_token { + token.to_string() + } else { + anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API"); + }; + let res: OpenAIChatResponse = client + .post( + self.configuration + .chat_endpoint + .as_ref() + .context("must specify `chat_endpoint` to use chat")?, + ) + .bearer_auth(token) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .json(&json!({ + "model": self.configuration.model, + "max_tokens": max_tokens, + "n": 1, + "top_p": self.configuration.top_p, + "presence_penalty": self.configuration.presence_penalty, + "frequency_penalty": self.configuration.frequency_penalty, + "temperature": self.configuration.temperature, + "messages": messages + })) + .send()? + .json()?; + if let Some(error) = res.error { + anyhow::bail!("{:?}", error.to_string()) + } else if let Some(choices) = res.choices { + Ok(choices[0].message.content.clone()) + } else { + anyhow::bail!("Unknown error while making request to OpenAI") + } } } +#[async_trait::async_trait] impl TransformerBackend for OpenAI { #[instrument(skip(self))] - fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> { + async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> { eprintln!("--------------{:?}---------------", prompt); - let prompt = format!("{} \n\n {}", prompt.context, prompt.code); - let insert_text = self.get_completion(&prompt, self.configuration.max_tokens.completion)?; + let max_tokens = self.configuration.max_tokens.completion; + let insert_text = match &self.configuration.chat { + Some(c) => match &c.completion { + Some(completion_messages) => { + let messages = format_chat_messages(completion_messages, prompt); + self.get_chat(messages, max_tokens)?
+ } + None => self.get_completion( + &format_context_code(&prompt.context, &prompt.code), + max_tokens, + )?, + }, + None => self.get_completion( + &format_context_code(&prompt.context, &prompt.code), + max_tokens, + )?, + }; Ok(DoCompletionResponse { insert_text }) } #[instrument(skip(self))] - fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> { + async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> { eprintln!("--------------{:?}---------------", prompt); - let prompt = format!("{} \n\n {}", prompt.context, prompt.code); - let generated_text = - self.get_completion(&prompt, self.configuration.max_tokens.completion)?; + let max_tokens = self.configuration.max_tokens.generation; + let generated_text = match &self.configuration.chat { + Some(c) => match &c.generation { + Some(completion_messages) => { + let messages = format_chat_messages(completion_messages, prompt); + self.get_chat(messages, max_tokens)? + } + None => self.get_completion( + &format_context_code(&prompt.context, &prompt.code), + max_tokens, + )?, + }, + None => self.get_completion( + &format_context_code(&prompt.context, &prompt.code), + max_tokens, + )?, + }; Ok(DoGenerateResponse { generated_text }) } #[instrument(skip(self))] - fn do_generate_stream( + async fn do_generate_stream( &self, request: &GenerateStreamRequest, ) -> anyhow::Result<DoGenerateStreamResponse> { unimplemented!() } } + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn openai_completion_do_completion() -> anyhow::Result<()> { + let configuration: configuration::OpenAI = serde_json::from_value(json!({ + "completions_endpoint": "https://api.openai.com/v1/completions", + "model": "gpt-3.5-turbo-instruct", + "auth_token_env_var_name": "OPENAI_API_KEY", + "max_tokens": { + "completion": 16, + "generation": 64 + }, + "max_context": 4096 + }))?; + let openai = OpenAI::new(configuration); + let prompt = Prompt::default_with_cursor(); + let response = openai.do_completion(&prompt).await?; + assert!(!response.insert_text.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn openai_chat_do_completion() -> anyhow::Result<()> { + let configuration: configuration::OpenAI = serde_json::from_value(json!({ + "chat_endpoint": "https://api.openai.com/v1/chat/completions", + "model": "gpt-3.5-turbo", + "auth_token_env_var_name": "OPENAI_API_KEY", + "chat": { + "completion": [ + { + "role": "system", + "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
+ }, + { + "role": "user", + "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}" + } + ], + }, + "max_tokens": { + "completion": 16, + "generation": 64 + }, + "max_context": 4096 + }))?; + let openai = OpenAI::new(configuration); + let prompt = Prompt::default_with_cursor(); + let response = openai.do_completion(&prompt).await?; + assert!(!response.insert_text.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn openai_completion_do_generate() -> anyhow::Result<()> { + let configuration: configuration::OpenAI = serde_json::from_value(json!({ + "completions_endpoint": "https://api.openai.com/v1/completions", + "model": "gpt-3.5-turbo-instruct", + "auth_token_env_var_name": "OPENAI_API_KEY", + "max_tokens": { + "completion": 16, + "generation": 64 + }, + "max_context": 4096 + }))?; + let openai = OpenAI::new(configuration); + let prompt = Prompt::default_with_cursor(); + let response = openai.do_generate(&prompt).await?; + assert!(!response.generated_text.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn openai_chat_do_generate() -> anyhow::Result<()> { + let configuration: configuration::OpenAI = serde_json::from_value(json!({ + "config": { + "chat_endpoint": "https://api.openai.com/v1/chat/completions", + "model": "gpt-3.5-turbo", + "auth_token_env_var_name": "OPENAI_API_KEY", + "chat": { + "generation": [ + { + "role": "system", + "content": "You are a coding assistant. You job is to generate a code snippet to replace .\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the location.\n- Only respond with code. Do not respond with anything that is not valid code." 
+ }, + { + "role": "user", + "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}" + } + ] + }, + "max_tokens": { + "completion": 16, + "generation": 64 + }, + "max_context": 4096 + }}))?; + let openai = OpenAI::new(configuration); + let prompt = Prompt::default_with_cursor(); + let response = openai.do_generate(&prompt).await?; + assert!(!response.generated_text.is_empty()); + Ok(()) + } +} diff --git a/src/utils.rs b/src/utils.rs index ce33964..ff24a26 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -20,15 +20,19 @@ pub fn tokens_to_estimated_characters(tokens: usize) -> usize { tokens * 4 } -pub fn format_chat_messages(messages: &Vec, prompt: &Prompt) -> Vec { +pub fn format_chat_messages(messages: &[ChatMessage], prompt: &Prompt) -> Vec { messages .iter() .map(|m| ChatMessage { role: m.role.to_owned(), content: m .content - .replace("{context}", &prompt.context) - .replace("{code}", &prompt.code), + .replace("{CONTEXT}", &prompt.context) + .replace("{CODE}", &prompt.code), }) .collect() } + +pub fn format_context_code(context: &str, code: &str) -> String { + format!("{context}\n\n{code}") +} diff --git a/src/worker.rs b/src/worker.rs index 87f8fd8..020a596 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -71,7 +71,7 @@ pub struct DoGenerateStreamResponse { } pub struct Worker { - transformer_backend: Box, + transformer_backend: Box, memory_backend: Arc>>, last_worker_request: Arc>>, connection: Arc, @@ -79,7 +79,7 @@ pub struct Worker { impl Worker { pub fn new( - transformer_backend: Box, + transformer_backend: Box, memory_backend: Arc>>, last_worker_request: Arc>>, connection: Arc, @@ -92,8 +92,7 @@ impl Worker { } } - #[instrument(skip(self))] - fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result { + async fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result { let prompt = self.memory_backend.lock().build_prompt( &request.params.text_document_position, PromptForType::Completion, @@ -102,7 +101,7 @@ impl Worker { .memory_backend .lock() .get_filter_text(&request.params.text_document_position)?; - let response = self.transformer_backend.do_completion(&prompt)?; + let response = self.transformer_backend.do_completion(&prompt).await?; let completion_text_edit = TextEdit::new( Range::new( Position::new( @@ -128,7 +127,7 @@ impl Worker { items: vec![item], }; let result = Some(CompletionResponse::List(completion_list)); - let result = serde_json::to_value(&result).unwrap(); + let result = serde_json::to_value(result).unwrap(); Ok(Response { id: request.id.clone(), result: Some(result), @@ -137,16 +136,16 @@ impl Worker { } #[instrument(skip(self))] - fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result { + async fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result { let prompt = self.memory_backend.lock().build_prompt( &request.params.text_document_position, PromptForType::Generate, )?; - let response = self.transformer_backend.do_generate(&prompt)?; + let response = self.transformer_backend.do_generate(&prompt).await?; let result = GenerateResult { generated_text: response.generated_text, }; - let result = serde_json::to_value(&result).unwrap(); + let result = serde_json::to_value(result).unwrap(); Ok(Response { id: request.id.clone(), result: Some(result), @@ -154,36 +153,48 @@ impl Worker { }) } - pub fn run(self) { + pub fn run(self) -> anyhow::Result<()> { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build()?; loop { let option_worker_request: Option = { 
let mut completion_request = self.last_worker_request.lock(); std::mem::take(&mut *completion_request) }; if let Some(request) = option_worker_request { - let response = match request { - WorkerRequest::Completion(request) => match self.do_completion(&request) { - Ok(r) => r, - Err(e) => Response { - id: request.id, - result: None, - error: Some(e.to_response_error(-32603)), - }, - }, - WorkerRequest::Generate(request) => match self.do_generate(&request) { - Ok(r) => r, - Err(e) => Response { - id: request.id, - result: None, - error: Some(e.to_response_error(-32603)), - }, - }, - WorkerRequest::GenerateStream(_) => panic!("Streaming is not supported yet"), - }; - self.connection - .sender - .send(Message::Response(response)) - .expect("Error sending message"); + runtime.spawn(async move { + let response = match request { + WorkerRequest::Completion(request) => { + match self.do_completion(&request).await { + Ok(r) => r, + Err(e) => Response { + id: request.id, + result: None, + error: Some(e.to_response_error(-32603)), + }, + } + } + WorkerRequest::Generate(request) => { + match self.do_generate(&request).await { + Ok(r) => r, + Err(e) => Response { + id: request.id, + result: None, + error: Some(e.to_response_error(-32603)), + }, + } + } + WorkerRequest::GenerateStream(_) => { + panic!("Streaming is not supported yet") + } + }; + self.connection + .sender + .send(Message::Response(response)) + .expect("Error sending message"); + }); } thread::sleep(std::time::Duration::from_millis(5)); }
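
Note on the new chat support: with this change the OpenAI backend is driven either through `completions_endpoint` or through `chat_endpoint` plus a list of `chat` messages whose `{CONTEXT}` and `{CODE}` placeholders are filled in by `format_chat_messages`. A condensed sketch of the configuration shape, mirroring the `openai_chat_do_completion` test above (the endpoint, model name, and message wording are illustrative only):

    let config: configuration::OpenAI = serde_json::from_value(serde_json::json!({
        "chat_endpoint": "https://api.openai.com/v1/chat/completions",
        "model": "gpt-3.5-turbo",
        "auth_token_env_var_name": "OPENAI_API_KEY",
        "chat": {
            // Messages used for completion requests; a separate "generation"
            // list can be supplied for generation requests.
            "completion": [
                { "role": "system", "content": "You are a coding assistant." },
                { "role": "user", "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}" }
            ]
        },
        "max_tokens": { "completion": 16, "generation": 64 },
        "max_context": 4096
    }))?;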
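The `MemoryBackend` and `TransformerBackend` traits both become async here via the `async-trait` crate added to Cargo.toml. A self-contained sketch of that pattern, with made-up names, showing why the attribute goes on both the trait and every impl:

    use async_trait::async_trait;

    #[async_trait]
    trait Backend {
        // The attribute rewrites this into a method returning a boxed future,
        // so the trait stays usable as a trait object (Box<dyn Backend>).
        async fn do_work(&self, input: &str) -> anyhow::Result<String>;
    }

    struct Echo;

    #[async_trait]
    impl Backend for Echo {
        async fn do_work(&self, input: &str) -> anyhow::Result<String> {
            Ok(input.to_owned())
        }
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let backend: Box<dyn Backend + Send + Sync> = Box::new(Echo);
        println!("{}", backend.do_work("hello").await?);
        Ok(())
    }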
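The response structs in the OpenAI backend make both `choices` and `error` optional because OpenAI compatible servers return one or the other, and only the fields that are actually read get decoded. A standalone sketch of that decode-and-bail pattern (struct and function names here are illustrative, not the ones in the diff):

    use serde::Deserialize;
    use serde_json::Value;

    #[derive(Deserialize)]
    struct Choice {
        text: String,
    }

    #[derive(Deserialize)]
    struct CompletionsResponse {
        // Exactly one of these is expected to be present in practice.
        choices: Option<Vec<Choice>>,
        error: Option<Value>,
    }

    fn first_choice(raw: &str) -> anyhow::Result<String> {
        let res: CompletionsResponse = serde_json::from_str(raw)?;
        if let Some(error) = res.error {
            anyhow::bail!("{}", error);
        }
        res.choices
            .and_then(|mut choices| {
                if choices.is_empty() {
                    None
                } else {
                    Some(choices.remove(0).text)
                }
            })
            .ok_or_else(|| anyhow::anyhow!("unknown error while making request to OpenAI"))
    }

    fn main() -> anyhow::Result<()> {
        println!("{}", first_choice(r#"{"choices":[{"text":"fn main() {}"}]}"#)?);
        Ok(())
    }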