diff --git a/design.txt b/design.txt new file mode 100644 index 0000000..d9595c0 --- /dev/null +++ b/design.txt @@ -0,0 +1,23 @@ +# Overview + +LSP AI should support multiple transform_backends: +- Python - LLAMA CPP +- Python - SOME OTHER LIBRARY +- PostgresML + +pub trait TransformBackend { + // These all take memory backends as an argument + do_completion() + do_generate() + do_generate_stream() +} + +LSP AI should support multiple memory_backends: +- SIMPLE FILE STORE +- IN MEMORY VECTOR STORE +- PostgresML + +pub trait MemoryBackend { + // Some file change ones + get_context() // Depending on the memory backend this will do very different things +} diff --git a/editors/vscode/package.json b/editors/vscode/package.json index 29e25da..f143db2 100644 --- a/editors/vscode/package.json +++ b/editors/vscode/package.json @@ -24,7 +24,17 @@ "command": "lsp-ai.generateStream", "title": "LSP AI Generate Stream" } - ] + ], + "configuration": { + "title": "Configuration", + "properties": { + "configuration.json": { + "type": "json", + "default": "{}", + "description": "JSON configuration for LSP AI" + } + } + } }, "devDependencies": { "@types/node": "^20.11.0", diff --git a/src/configuration.rs b/src/configuration.rs new file mode 100644 index 0000000..bd3c24c --- /dev/null +++ b/src/configuration.rs @@ -0,0 +1,181 @@ +use anyhow::{Context, Result}; +use serde::Deserialize; +use serde_json::Value; +use std::collections::HashMap; + +#[cfg(target_os = "macos")] +const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024; + +const DEFAULT_MAX_COMPLETION_TOKENS: usize = 32; +const DEFAULT_MAX_GENERATION_TOKENS: usize = 256; + +pub enum ValidMemoryBackend { + FileStore, + PostgresML, +} + +pub enum ValidTransformerBackend { + LlamaCPP, + PostgresML, +} + +// TODO: Review this for real lol +#[derive(Clone, Deserialize)] +pub struct FIM { + prefix: String, + middle: String, + suffix: String, +} + +// TODO: Add some default things +#[derive(Clone, Deserialize)] +pub struct MaxNewTokens { + pub completion: usize, + pub generation: usize, +} + +impl Default for MaxNewTokens { + fn default() -> Self { + Self { + completion: DEFAULT_MAX_COMPLETION_TOKENS, + generation: DEFAULT_MAX_GENERATION_TOKENS, + } + } +} + +#[derive(Clone, Deserialize)] +struct ValidMemoryConfiguration { + file_store: Option, +} + +#[derive(Clone, Deserialize)] +struct ModelGGUF { + repository: String, + name: String, + // Fill in the middle support + fim: Option, + // The maximum number of new tokens to generate + #[serde(default)] + max_new_tokens: MaxNewTokens, + // Kwargs passed to LlamaCPP + #[serde(flatten)] + kwargs: HashMap, +} + +#[derive(Clone, Deserialize)] +struct ValidMacTransformerConfiguration { + model_gguf: Option, +} + +#[derive(Clone, Deserialize)] +struct ValidLinuxTransformerConfiguration { + model_gguf: Option, +} + +#[derive(Clone, Deserialize)] +struct ValidConfiguration { + memory: ValidMemoryConfiguration, + // TODO: Add renam here + #[cfg(target_os = "macos")] + #[serde(alias = "macos")] + transformer: ValidMacTransformerConfiguration, + #[cfg(target_os = "linux")] + #[serde(alias = "linux")] + transformer: ValidLinuxTransformerConfiguration, +} + +#[derive(Clone)] +pub struct Configuration { + valid_config: ValidConfiguration, +} + +impl Configuration { + pub fn new(mut args: Value) -> Result { + let configuration_args = args + .as_object_mut() + .context("Server configuration must be a JSON object")? 
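+            // Client settings arrive as LSP `initializationOptions`; pull that object
+            // out of the initialize params before deserializing it into the validated
+            // configuration below.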
+ .remove("initializationOptions") + .unwrap_or_default(); + let valid_args: ValidConfiguration = serde_json::from_value(configuration_args)?; + // TODO: Make sure they only specified one model or something ya know + Ok(Self { + valid_config: valid_args, + }) + } + + pub fn get_memory_backend(&self) -> Result { + if self.valid_config.memory.file_store.is_some() { + Ok(ValidMemoryBackend::FileStore) + } else { + anyhow::bail!("Invalid memory configuration") + } + } + + pub fn get_transformer_backend(&self) -> Result { + if self.valid_config.transformer.model_gguf.is_some() { + Ok(ValidTransformerBackend::LlamaCPP) + } else { + anyhow::bail!("Invalid model configuration") + } + } + + pub fn get_maximum_context_length(&self) -> usize { + if let Some(model_gguf) = &self.valid_config.transformer.model_gguf { + model_gguf + .kwargs + .get("n_ctx") + .map(|v| { + v.as_u64() + .map(|u| u as usize) + .unwrap_or(DEFAULT_LLAMA_CPP_N_CTX) + }) + .unwrap_or(DEFAULT_LLAMA_CPP_N_CTX) + } else { + panic!("We currently only support gguf models using llama cpp") + } + } + + pub fn get_max_new_tokens(&self) -> &MaxNewTokens { + if let Some(model_gguf) = &self.valid_config.transformer.model_gguf { + &model_gguf.max_new_tokens + } else { + panic!("We currently only support gguf models using llama cpp") + } + } + + pub fn supports_fim(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn custom_mac_gguf_model() { + let args = json!({ + "initializationOptions": { + "memory": { + "file_store": {} + }, + "macos": { + "model_gguf": { + "repository": "deepseek-coder-6.7b-base", + "name": "Q4_K_M.gguf", + "max_new_tokens": { + "completion": 32, + "generation": 256, + }, + "n_ctx": 2048, + "n_threads": 8, + "n_gpu_layers": 35, + "chat_template": "", + } + }, + } + }); + let _ = Configuration::new(args).unwrap(); + } +} diff --git a/src/custom_requests/generate_stream.rs b/src/custom_requests/generate_stream.rs index 42e551b..14d73d5 100644 --- a/src/custom_requests/generate_stream.rs +++ b/src/custom_requests/generate_stream.rs @@ -1,4 +1,4 @@ -use lsp_types::{PartialResultParams, ProgressToken, TextDocumentPositionParams}; +use lsp_types::{ProgressToken, TextDocumentPositionParams}; use serde::{Deserialize, Serialize}; pub enum GenerateStream {} diff --git a/src/main.rs b/src/main.rs index 3407f3e..a8186b8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,31 +1,28 @@ -use anyhow::{Context, Result}; +use anyhow::Result; + use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId}; use lsp_types::{ request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, }; -use once_cell::sync::Lazy; use parking_lot::Mutex; -use pyo3::prelude::*; -use ropey::Rope; -use serde::Deserialize; -use std::{collections::HashMap, sync::Arc, thread}; +use std::{sync::Arc, thread}; +mod configuration; mod custom_requests; +mod memory_backends; +mod transformer_backends; +mod utils; mod worker; +use configuration::Configuration; use custom_requests::generate::Generate; -use worker::{CompletionRequest, GenerateRequest, WorkerRequest}; +use memory_backends::MemoryBackend; +use transformer_backends::TransformerBackend; +use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest}; use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest}; -pub static PY_MODULE: Lazy>> = Lazy::new(|| { - pyo3::Python::with_gil(|py| -> 
Result> { - let src = include_str!("python/transformers.py"); - Ok(pyo3::types::PyModule::from_code(py, src, "transformers.py", "transformers")?.into()) - }) -}); - // Taken directly from: https://github.com/rust-lang/rust-analyzer fn notification_is(notification: &Notification) -> bool { notification.method == N::METHOD @@ -52,55 +49,47 @@ fn main() -> Result<()> { )), ..Default::default() })?; - let initialization_params = connection.initialize(server_capabilities)?; + let initialization_args = connection.initialize(server_capabilities)?; - // Activate the python venv - Python::with_gil(|py| -> Result<()> { - let activate: Py = PY_MODULE - .as_ref() - .map_err(anyhow::Error::msg)? - .getattr(py, "activate_venv")?; - - activate.call1(py, ("/Users/silas/Projects/lsp-ai/venv",))?; - Ok(()) - })?; - - main_loop(connection, initialization_params)?; + main_loop(connection, initialization_args)?; io_threads.join()?; Ok(()) } -#[derive(Deserialize)] -struct Params {} - // This main loop is tricky // We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get // Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request -fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { - let _params: Params = serde_json::from_value(params)?; +// Note that we also want to have the memory backend in the worker thread as that may also involve heavy computations +fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { + let args = Configuration::new(args)?; - // Set the model - Python::with_gil(|py| -> Result<()> { - let activate: Py = PY_MODULE - .as_ref() - .map_err(anyhow::Error::msg)? - .getattr(py, "set_model")?; - activate.call1(py, ("",))?; - Ok(()) - })?; + // Set the transformer_backend + let transformer_backend: Box = args.clone().try_into()?; + transformer_backend.init()?; - // Prep variables + // Set the memory_backend + let memory_backend: Arc>> = + Arc::new(Mutex::new(args.clone().try_into()?)); + + // Wrap the connection for sharing between threads let connection = Arc::new(connection); - let mut file_map: HashMap = HashMap::new(); // How we communicate between the worker and receiver threads let last_worker_request = Arc::new(Mutex::new(None)); // Thread local variables + let thread_memory_backend = memory_backend.clone(); let thread_last_worker_request = last_worker_request.clone(); let thread_connection = connection.clone(); + // TODO: Pass some backend into here thread::spawn(move || { - worker::run(thread_last_worker_request, thread_connection); + Worker::new( + transformer_backend, + thread_memory_backend, + thread_last_worker_request, + thread_connection, + ) + .run(); }); for msg in &connection.receiver { @@ -115,41 +104,30 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { if request_is::(&req) { match cast::(req) { Ok((id, params)) => { - let rope = file_map - .get(params.text_document_position.text_document.uri.as_str()) - .context("Error file not found")? 
- .clone(); + eprintln!("******{:?}********", id); let mut lcr = last_worker_request.lock(); - let completion_request = CompletionRequest::new(id, params, rope); + let completion_request = CompletionRequest::new(id, params); *lcr = Some(WorkerRequest::Completion(completion_request)); } - Err(err) => panic!("{err:?}"), + Err(err) => eprintln!("{err:?}"), } } else if request_is::(&req) { match cast::(req) { Ok((id, params)) => { - let rope = file_map - .get(params.text_document_position.text_document.uri.as_str()) - .context("Error file not found")? - .clone(); let mut lcr = last_worker_request.lock(); - let completion_request = GenerateRequest::new(id, params, rope); + let completion_request = GenerateRequest::new(id, params); *lcr = Some(WorkerRequest::Generate(completion_request)); } - Err(err) => panic!("{err:?}"), + Err(err) => eprintln!("{err:?}"), } } else if request_is::(&req) { match cast::(req) { Ok((id, params)) => { - let rope = file_map - .get(params.text_document_position.text_document.uri.as_str()) - .context("Error file not found")? - .clone(); let mut lcr = last_worker_request.lock(); - let completion_request = GenerateStreamRequest::new(id, params, rope); + let completion_request = GenerateStreamRequest::new(id, params); *lcr = Some(WorkerRequest::GenerateStream(completion_request)); } - Err(err) => panic!("{err:?}"), + Err(err) => eprintln!("{err:?}"), } } else { eprintln!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream") @@ -158,33 +136,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { Message::Notification(not) => { if notification_is::(¬) { let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?; - let rope = Rope::from_str(¶ms.text_document.text); - file_map.insert(params.text_document.uri.to_string(), rope); + memory_backend.lock().opened_text_document(params)?; } else if notification_is::(¬) { let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?; - let rope = file_map - .get_mut(params.text_document.uri.as_str()) - .context("Error trying to get file that does not exist")?; - for change in params.content_changes { - // If range is ommitted, text is the new text of the document - if let Some(range) = change.range { - let start_index = rope.line_to_char(range.start.line as usize) - + range.start.character as usize; - let end_index = rope.line_to_char(range.end.line as usize) - + range.end.character as usize; - rope.remove(start_index..end_index); - rope.insert(start_index, &change.text); - } else { - *rope = Rope::from_str(&change.text); - } - } + memory_backend.lock().changed_text_document(params)?; } else if notification_is::(¬) { let params: RenameFilesParams = serde_json::from_value(not.params)?; - for file_rename in params.files { - if let Some(rope) = file_map.remove(&file_rename.old_uri) { - file_map.insert(file_rename.new_uri, rope); - } - } + memory_backend.lock().renamed_file(params)?; } } _ => (), diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs new file mode 100644 index 0000000..ca88cfd --- /dev/null +++ b/src/memory_backends/file_store.rs @@ -0,0 +1,98 @@ +use anyhow::Context; +use lsp_types::TextDocumentPositionParams; +use ropey::Rope; +use std::collections::HashMap; + +use crate::configuration::Configuration; + +use super::MemoryBackend; + +pub struct FileStore { + configuration: Configuration, + file_map: HashMap, +} + +impl FileStore { + pub fn new(configuration: 
Configuration) -> Self { + Self { + configuration, + file_map: HashMap::new(), + } + } +} + +impl MemoryBackend for FileStore { + fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result { + let rope = self + .file_map + .get(position.text_document.uri.as_str()) + .context("Error file not found")? + .clone(); + + if self.configuration.supports_fim() { + // We will want to have some kind of infill support we add + // rope.insert(cursor_index, "<|fim_hole|>"); + // rope.insert(0, "<|fim_start|>"); + // rope.insert(rope.len_chars(), "<|fim_end|>"); + // let prompt = rope.to_string(); + unimplemented!() + } else { + // Convert rope to correct prompt for llm + let cursor_index = rope.line_to_char(position.position.line as usize) + + position.position.character as usize; + + let start = cursor_index + .checked_sub(self.configuration.get_maximum_context_length()) + .unwrap_or(0); + eprintln!("############ {start} - {cursor_index} #############"); + + Ok(rope + .get_slice(start..cursor_index) + .context("Error getting rope slice")? + .to_string()) + } + } + + fn opened_text_document( + &mut self, + params: lsp_types::DidOpenTextDocumentParams, + ) -> anyhow::Result<()> { + let rope = Rope::from_str(¶ms.text_document.text); + self.file_map + .insert(params.text_document.uri.to_string(), rope); + Ok(()) + } + + fn changed_text_document( + &mut self, + params: lsp_types::DidChangeTextDocumentParams, + ) -> anyhow::Result<()> { + let rope = self + .file_map + .get_mut(params.text_document.uri.as_str()) + .context("Error trying to get file that does not exist")?; + for change in params.content_changes { + // If range is ommitted, text is the new text of the document + if let Some(range) = change.range { + let start_index = + rope.line_to_char(range.start.line as usize) + range.start.character as usize; + let end_index = + rope.line_to_char(range.end.line as usize) + range.end.character as usize; + rope.remove(start_index..end_index); + rope.insert(start_index, &change.text); + } else { + *rope = Rope::from_str(&change.text); + } + } + Ok(()) + } + + fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { + for file_rename in params.files { + if let Some(rope) = self.file_map.remove(&file_rename.old_uri) { + self.file_map.insert(file_rename.new_uri, rope); + } + } + Ok(()) + } +} diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs new file mode 100644 index 0000000..9d1de07 --- /dev/null +++ b/src/memory_backends/mod.rs @@ -0,0 +1,31 @@ +use lsp_types::{ + DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams, + TextDocumentPositionParams, +}; + +use crate::configuration::{Configuration, ValidMemoryBackend}; + +pub mod file_store; + +pub trait MemoryBackend { + fn init(&self) -> anyhow::Result<()> { + Ok(()) + } + fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>; + fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>; + fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>; + fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result; +} + +impl TryFrom for Box { + type Error = anyhow::Error; + + fn try_from(configuration: Configuration) -> Result { + match configuration.get_memory_backend()? 
{ + ValidMemoryBackend::FileStore => { + Ok(Box::new(file_store::FileStore::new(configuration))) + } + _ => unimplemented!(), + } + } +} diff --git a/src/python/transformers.py b/src/python/transformers.py deleted file mode 100644 index de2daeb..0000000 --- a/src/python/transformers.py +++ /dev/null @@ -1,44 +0,0 @@ -import sys -import os - -from llama_cpp import Llama - - -model = None - - -def activate_venv(venv): - if sys.platform in ('win32', 'win64', 'cygwin'): - activate_this = os.path.join(venv, 'Scripts', 'activate_this.py') - else: - activate_this = os.path.join(venv, 'bin', 'activate_this.py') - - if os.path.exists(activate_this): - exec(open(activate_this).read(), dict(__file__=activate_this)) - return True - else: - print(f"Virtualenv not found: {venv}", file=sys.stderr) - return False - - - -def set_model(filler): - global model - model = Llama( - # model_path="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", # Download the model file first - model_path="/Users/silas/Projects/Tests/lsp-ai-tests/deepseek-coder-6.7b-base.Q4_K_M.gguf", # Download the model file first - n_ctx=2048, # The max sequence length to use - note that longer sequence lengths require much more resources - n_threads=8, # The number of CPU threads to use, tailor to your system and the resulting performance - n_gpu_layers=35 # The number of layers to offload to GPU, if you have GPU acceleration available - ) - - -def transform(input): - # Simple inference example - output = model( - input, # Prompt - max_tokens=32, # Generate up to 512 tokens - stop=["<|EOT|>"], # Example stop token - not necessarily correct for this specific model! Please check before using. - echo=False # Whether to echo the prompt - ) - return output["choices"][0]["text"] diff --git a/src/transformer_backends/llama_cpp/mod.rs b/src/transformer_backends/llama_cpp/mod.rs new file mode 100644 index 0000000..6675d10 --- /dev/null +++ b/src/transformer_backends/llama_cpp/mod.rs @@ -0,0 +1,88 @@ +use crate::{ + configuration::Configuration, + worker::{ + DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest, + }, +}; + +use super::TransformerBackend; +use once_cell::sync::Lazy; +use pyo3::prelude::*; + +pub static PY_MODULE: Lazy>> = Lazy::new(|| { + pyo3::Python::with_gil(|py| -> anyhow::Result> { + let src = include_str!("python/transformers.py"); + Ok(pyo3::types::PyModule::from_code(py, src, "transformers.py", "transformers")?.into()) + }) +}); + +pub struct LlamaCPP { + configuration: Configuration, +} + +impl LlamaCPP { + pub fn new(configuration: Configuration) -> Self { + Self { configuration } + } +} + +impl TransformerBackend for LlamaCPP { + fn init(&self) -> anyhow::Result<()> { + // Activate the python venv + Python::with_gil(|py| -> anyhow::Result<()> { + let activate: Py = PY_MODULE + .as_ref() + .map_err(anyhow::Error::msg)? + .getattr(py, "activate_venv")?; + activate.call1(py, ("/Users/silas/Projects/lsp-ai/venv",))?; + Ok(()) + })?; + // Set the model + Python::with_gil(|py| -> anyhow::Result<()> { + let activate: Py = PY_MODULE + .as_ref() + .map_err(anyhow::Error::msg)? + .getattr(py, "set_model")?; + activate.call1(py, ("",))?; + Ok(()) + })?; + Ok(()) + } + + fn do_completion(&self, prompt: &str) -> anyhow::Result { + let max_new_tokens = self.configuration.get_max_new_tokens().completion; + Python::with_gil(|py| -> anyhow::Result { + let transform: Py = PY_MODULE + .as_ref() + .map_err(anyhow::Error::msg)? 
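+                // `transform` is defined in llama_cpp/python/transformers.py; it takes
+                // the prompt and the maximum number of new tokens to generate.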
+ .getattr(py, "transform")?; + + let out: String = transform.call1(py, (prompt, max_new_tokens))?.extract(py)?; + Ok(out) + }) + .map(|insert_text| DoCompletionResponse { insert_text }) + } + + fn do_generate(&self, prompt: &str) -> anyhow::Result { + let max_new_tokens = self.configuration.get_max_new_tokens().generation; + Python::with_gil(|py| -> anyhow::Result { + let transform: Py = PY_MODULE + .as_ref() + .map_err(anyhow::Error::msg)? + .getattr(py, "transform")?; + + let out: String = transform.call1(py, (prompt, max_new_tokens))?.extract(py)?; + Ok(out) + }) + .map(|generated_text| DoGenerateResponse { generated_text }) + } + + fn do_generate_stream( + &self, + request: &GenerateStreamRequest, + ) -> anyhow::Result { + Ok(DoGenerateStreamResponse { + generated_text: "".to_string(), + }) + } +} diff --git a/src/transformer_backends/llama_cpp/python/transformers.py b/src/transformer_backends/llama_cpp/python/transformers.py new file mode 100644 index 0000000..c6ef9bf --- /dev/null +++ b/src/transformer_backends/llama_cpp/python/transformers.py @@ -0,0 +1,45 @@ +import sys +import os + +from llama_cpp import Llama + + +model = None + + +def activate_venv(venv): + if sys.platform in ("win32", "win64", "cygwin"): + activate_this = os.path.join(venv, "Scripts", "activate_this.py") + else: + activate_this = os.path.join(venv, "bin", "activate_this.py") + + if os.path.exists(activate_this): + exec(open(activate_this).read(), dict(__file__=activate_this)) + return True + else: + print(f"Virtualenv not found: {venv}", file=sys.stderr) + return False + + +def set_model(filler): + global model + model = Llama( + # model_path="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", # Download the model file first + model_path="/Users/silas/Projects/Tests/lsp-ai-tests/deepseek-coder-6.7b-base.Q4_K_M.gguf", # Download the model file first + n_ctx=2048, # The max sequence length to use - note that longer sequence lengths require much more resources + n_threads=8, # The number of CPU threads to use, tailor to your system and the resulting performance + n_gpu_layers=35, # The number of layers to offload to GPU, if you have GPU acceleration available + ) + + +def transform(input, max_tokens): + # Simple inference example + output = model( + input, # Prompt + max_tokens=max_tokens, # Generate up to max tokens + # stop=[ + # "<|EOT|>" + # ], # Example stop token - not necessarily correct for this specific model! Please check before using. + echo=False, # Whether to echo the prompt + ) + return output["choices"][0]["text"] diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs new file mode 100644 index 0000000..1c1227d --- /dev/null +++ b/src/transformer_backends/mod.rs @@ -0,0 +1,32 @@ +use crate::{ + configuration::{Configuration, ValidTransformerBackend}, + worker::{ + DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateRequest, + GenerateStreamRequest, + }, +}; + +pub mod llama_cpp; + +pub trait TransformerBackend { + fn init(&self) -> anyhow::Result<()>; + fn do_completion(&self, prompt: &str) -> anyhow::Result; + fn do_generate(&self, prompt: &str) -> anyhow::Result; + fn do_generate_stream( + &self, + request: &GenerateStreamRequest, + ) -> anyhow::Result; +} + +impl TryFrom for Box { + type Error = anyhow::Error; + + fn try_from(configuration: Configuration) -> Result { + match configuration.get_transformer_backend()? 
{ + ValidTransformerBackend::LlamaCPP => { + Ok(Box::new(llama_cpp::LlamaCPP::new(configuration))) + } + _ => unimplemented!(), + } + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..cb2ae49 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,15 @@ +use lsp_server::ResponseError; + +pub trait ToResponseError { + fn to_response_error(&self, code: i32) -> ResponseError; +} + +impl ToResponseError for anyhow::Error { + fn to_response_error(&self, code: i32) -> ResponseError { + ResponseError { + code: -32603, + message: self.to_string(), + data: None, + } + } +} diff --git a/src/worker.rs b/src/worker.rs new file mode 100644 index 0000000..ceddfbe --- /dev/null +++ b/src/worker.rs @@ -0,0 +1,187 @@ +use lsp_server::{Connection, Message, RequestId, Response}; +use lsp_types::{ + CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse, + Position, Range, TextEdit, +}; +use parking_lot::Mutex; +use std::{sync::Arc, thread}; + +use crate::custom_requests::generate::{GenerateParams, GenerateResult}; +use crate::custom_requests::generate_stream::{GenerateStreamParams, GenerateStreamResult}; +use crate::memory_backends::MemoryBackend; +use crate::transformer_backends::TransformerBackend; +use crate::utils::ToResponseError; + +#[derive(Clone)] +pub struct CompletionRequest { + id: RequestId, + params: CompletionParams, +} + +impl CompletionRequest { + pub fn new(id: RequestId, params: CompletionParams) -> Self { + Self { id, params } + } +} + +#[derive(Clone)] +pub struct GenerateRequest { + id: RequestId, + params: GenerateParams, +} + +impl GenerateRequest { + pub fn new(id: RequestId, params: GenerateParams) -> Self { + Self { id, params } + } +} + +#[derive(Clone)] +pub struct GenerateStreamRequest { + id: RequestId, + params: GenerateStreamParams, +} + +impl GenerateStreamRequest { + pub fn new(id: RequestId, params: GenerateStreamParams) -> Self { + Self { id, params } + } +} + +#[derive(Clone)] +pub enum WorkerRequest { + Completion(CompletionRequest), + Generate(GenerateRequest), + GenerateStream(GenerateStreamRequest), +} + +pub struct DoCompletionResponse { + pub insert_text: String, +} + +pub struct DoGenerateResponse { + pub generated_text: String, +} + +pub struct DoGenerateStreamResponse { + pub generated_text: String, +} + +pub struct Worker { + transformer_backend: Box, + memory_backend: Arc>>, + last_worker_request: Arc>>, + connection: Arc, +} + +impl Worker { + pub fn new( + transformer_backend: Box, + memory_backend: Arc>>, + last_worker_request: Arc>>, + connection: Arc, + ) -> Self { + Self { + transformer_backend, + memory_backend, + last_worker_request, + connection, + } + } + + fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result { + let prompt = self + .memory_backend + .lock() + .build_prompt(&request.params.text_document_position)?; + eprintln!("\n\n****************{}***************\n\n", prompt); + let response = self.transformer_backend.do_completion(&prompt)?; + eprintln!( + "\n\n****************{}***************\n\n", + response.insert_text + ); + let completion_text_edit = TextEdit::new( + Range::new( + Position::new( + request.params.text_document_position.position.line, + request.params.text_document_position.position.character, + ), + Position::new( + request.params.text_document_position.position.line, + request.params.text_document_position.position.character, + ), + ), + response.insert_text.clone(), + ); + let item = CompletionItem { + label: format!("ai - {}", response.insert_text), + 
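+            // The edit covers a zero-width range at the request position, so accepting
+            // the completion inserts the generated text without replacing anything the
+            // user has already typed.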
text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)), + kind: Some(CompletionItemKind::TEXT), + ..Default::default() + }; + let completion_list = CompletionList { + is_incomplete: false, + items: vec![item], + }; + let result = Some(CompletionResponse::List(completion_list)); + let result = serde_json::to_value(&result).unwrap(); + Ok(Response { + id: request.id.clone(), + result: Some(result), + error: None, + }) + } + + fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result { + let prompt = self + .memory_backend + .lock() + .build_prompt(&request.params.text_document_position)?; + eprintln!("\n\n****************{}***************\n\n", prompt); + let response = self.transformer_backend.do_generate(&prompt)?; + let result = GenerateResult { + generated_text: response.generated_text, + }; + let result = serde_json::to_value(&result).unwrap(); + Ok(Response { + id: request.id.clone(), + result: Some(result), + error: None, + }) + } + + pub fn run(self) { + loop { + let option_worker_request: Option = { + let mut completion_request = self.last_worker_request.lock(); + std::mem::take(&mut *completion_request) + }; + if let Some(request) = option_worker_request { + let response = match request { + WorkerRequest::Completion(request) => match self.do_completion(&request) { + Ok(r) => r, + Err(e) => Response { + id: request.id, + result: None, + error: Some(e.to_response_error(-32603)), + }, + }, + WorkerRequest::Generate(request) => match self.do_generate(&request) { + Ok(r) => r, + Err(e) => Response { + id: request.id, + result: None, + error: Some(e.to_response_error(-32603)), + }, + }, + WorkerRequest::GenerateStream(_) => panic!("Streaming is not supported yet"), + }; + self.connection + .sender + .send(Message::Response(response)) + .expect("Error sending message"); + } + thread::sleep(std::time::Duration::from_millis(5)); + } + } +} diff --git a/src/worker/completion.rs b/src/worker/completion.rs deleted file mode 100644 index 58de2f7..0000000 --- a/src/worker/completion.rs +++ /dev/null @@ -1,61 +0,0 @@ -use lsp_server::ResponseError; -use pyo3::prelude::*; - -use super::CompletionRequest; -use crate::PY_MODULE; - -pub struct DoCompletionResponse { - pub insert_text: String, - pub filter_text: String, -} - -pub fn do_completion(request: &CompletionRequest) -> Result { - let filter_text = request - .rope - .get_line(request.params.text_document_position.position.line as usize) - .ok_or(ResponseError { - code: -32603, // Maybe we want a different error code here? - message: "Error getting line in requested document".to_string(), - data: None, - })? - .to_string(); - - // Convert rope to correct prompt for llm - let cursor_index = request - .rope - .line_to_char(request.params.text_document_position.position.line as usize) - + request.params.text_document_position.position.character as usize; - - // We will want to have some kind of infill support we add - // rope.insert(cursor_index, "<|fim_hole|>"); - // rope.insert(0, "<|fim_start|>"); - // rope.insert(rope.len_chars(), "<|fim_end|>"); - // let prompt = rope.to_string(); - - let prompt = request - .rope - .get_slice(0..cursor_index) - .expect("Error getting rope slice") - .to_string(); - - eprintln!("\n\n****{prompt}****\n\n"); - - Python::with_gil(|py| -> anyhow::Result { - let transform: Py = PY_MODULE - .as_ref() - .map_err(anyhow::Error::msg)? 
- .getattr(py, "transform")?; - - let out: String = transform.call1(py, (prompt,))?.extract(py)?; - Ok(out) - }) - .map(|insert_text| DoCompletionResponse { - insert_text, - filter_text, - }) - .map_err(|e| ResponseError { - code: -32603, - message: e.to_string(), - data: None, - }) -} diff --git a/src/worker/generate.rs b/src/worker/generate.rs deleted file mode 100644 index 917c42d..0000000 --- a/src/worker/generate.rs +++ /dev/null @@ -1,47 +0,0 @@ -use lsp_server::ResponseError; -use pyo3::prelude::*; - -use super::{GenerateRequest, GenerateStreamRequest}; -use crate::PY_MODULE; - -pub struct DoGenerateResponse { - pub generated_text: String, -} - -pub fn do_generate(request: &GenerateRequest) -> Result { - // Convert rope to correct prompt for llm - let cursor_index = request - .rope - .line_to_char(request.params.text_document_position.position.line as usize) - + request.params.text_document_position.position.character as usize; - - // We will want to have some kind of infill support we add - // rope.insert(cursor_index, "<|fim_hole|>"); - // rope.insert(0, "<|fim_start|>"); - // rope.insert(rope.len_chars(), "<|fim_end|>"); - // let prompt = rope.to_string(); - - let prompt = request - .rope - .get_slice(0..cursor_index) - .expect("Error getting rope slice") - .to_string(); - - eprintln!("\n\n****{prompt}****\n\n"); - - Python::with_gil(|py| -> anyhow::Result { - let transform: Py = PY_MODULE - .as_ref() - .map_err(anyhow::Error::msg)? - .getattr(py, "transform")?; - - let out: String = transform.call1(py, (prompt,))?.extract(py)?; - Ok(out) - }) - .map(|generated_text| DoGenerateResponse { generated_text }) - .map_err(|e| ResponseError { - code: -32603, - message: e.to_string(), - data: None, - }) -} diff --git a/src/worker/generate_stream.rs b/src/worker/generate_stream.rs deleted file mode 100644 index 42981dd..0000000 --- a/src/worker/generate_stream.rs +++ /dev/null @@ -1,45 +0,0 @@ -use lsp_server::ResponseError; -use pyo3::prelude::*; - -use super::{GenerateRequest, GenerateStreamRequest}; -use crate::PY_MODULE; - -pub fn do_generate_stream(request: &GenerateStreamRequest) -> Result<(), ResponseError> { - // Convert rope to correct prompt for llm - // let cursor_index = request - // .rope - // .line_to_char(request.params.text_document_position.position.line as usize) - // + request.params.text_document_position.position.character as usize; - - // // We will want to have some kind of infill support we add - // // rope.insert(cursor_index, "<|fim_hole|>"); - // // rope.insert(0, "<|fim_start|>"); - // // rope.insert(rope.len_chars(), "<|fim_end|>"); - // // let prompt = rope.to_string(); - - // let prompt = request - // .rope - // .get_slice(0..cursor_index) - // .expect("Error getting rope slice") - // .to_string(); - - // eprintln!("\n\n****{prompt}****\n\n"); - - // Python::with_gil(|py| -> anyhow::Result { - // let transform: Py = PY_MODULE - // .as_ref() - // .map_err(anyhow::Error::msg)? 
- // .getattr(py, "transform")?; - - // let out: String = transform.call1(py, (prompt,))?.extract(py)?; - // Ok(out) - // }) - // .map(|generated_text| DoGenerateResponse { generated_text }) - // .map_err(|e| ResponseError { - // code: -32603, - // message: e.to_string(), - // data: None, - // }) - - Ok(()) -} diff --git a/src/worker/mod.rs b/src/worker/mod.rs deleted file mode 100644 index bfe1223..0000000 --- a/src/worker/mod.rs +++ /dev/null @@ -1,181 +0,0 @@ -use lsp_server::{Connection, Message, RequestId, Response}; -use lsp_types::{ - CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse, - Position, Range, TextEdit, -}; -use parking_lot::Mutex; -use ropey::Rope; -use std::{sync::Arc, thread}; - -mod completion; -mod generate; -mod generate_stream; - -use crate::custom_requests::generate::{GenerateParams, GenerateResult}; -use crate::custom_requests::generate_stream::{GenerateStreamParams, GenerateStreamResult}; -use completion::do_completion; -use generate::do_generate; -use generate_stream::do_generate_stream; - -#[derive(Clone)] -pub struct CompletionRequest { - id: RequestId, - params: CompletionParams, - rope: Rope, -} - -impl CompletionRequest { - pub fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self { - Self { id, params, rope } - } -} - -#[derive(Clone)] -pub struct GenerateRequest { - id: RequestId, - params: GenerateParams, - rope: Rope, -} - -impl GenerateRequest { - pub fn new(id: RequestId, params: GenerateParams, rope: Rope) -> Self { - Self { id, params, rope } - } -} - -#[derive(Clone)] -pub struct GenerateStreamRequest { - id: RequestId, - params: GenerateStreamParams, - rope: Rope, -} - -impl GenerateStreamRequest { - pub fn new(id: RequestId, params: GenerateStreamParams, rope: Rope) -> Self { - Self { id, params, rope } - } -} - -#[derive(Clone)] -pub enum WorkerRequest { - Completion(CompletionRequest), - Generate(GenerateRequest), - GenerateStream(GenerateStreamRequest), -} - -pub fn run(last_worker_request: Arc>>, connection: Arc) { - loop { - let option_worker_request: Option = { - let mut completion_request = last_worker_request.lock(); - std::mem::take(&mut *completion_request) - }; - if let Some(request) = option_worker_request { - let response = match request { - WorkerRequest::Completion(request) => match do_completion(&request) { - Ok(response) => { - let completion_text_edit = TextEdit::new( - Range::new( - Position::new( - request.params.text_document_position.position.line, - request.params.text_document_position.position.character, - ), - Position::new( - request.params.text_document_position.position.line, - request.params.text_document_position.position.character, - ), - ), - response.insert_text.clone(), - ); - let item = CompletionItem { - label: format!("ai - {}", response.insert_text), - filter_text: Some(response.filter_text), - text_edit: Some(lsp_types::CompletionTextEdit::Edit( - completion_text_edit, - )), - kind: Some(CompletionItemKind::TEXT), - ..Default::default() - }; - let completion_list = CompletionList { - is_incomplete: false, - items: vec![item], - }; - let result = Some(CompletionResponse::List(completion_list)); - let result = serde_json::to_value(&result).unwrap(); - Response { - id: request.id, - result: Some(result), - error: None, - } - } - Err(e) => Response { - id: request.id, - result: None, - error: Some(e), - }, - }, - WorkerRequest::Generate(request) => match do_generate(&request) { - Ok(result) => { - let result = GenerateResult { - generated_text: 
result.generated_text, - }; - let result = serde_json::to_value(&result).unwrap(); - Response { - id: request.id, - result: Some(result), - error: None, - } - } - Err(e) => Response { - id: request.id, - result: None, - error: Some(e), - }, - }, - WorkerRequest::GenerateStream(request) => match do_generate_stream(&request) { - Ok(result) => { - // let result = GenerateResult { - // generated_text: result.generated_text, - // }; - // let result = serde_json::to_value(&result).unwrap(); - let result = GenerateStreamResult { - generated_text: "test".to_string(), - partial_result_token: request.params.partial_result_token, - }; - let result = serde_json::to_value(&result).unwrap(); - Response { - id: request.id, - result: Some(result), - error: None, - } - } - Err(e) => Response { - id: request.id, - result: None, - error: Some(e), - }, - }, - }; - connection - .sender - .send(Message::Response(response.clone())) - .expect("Error sending response"); - connection - .sender - .send(Message::Response(response.clone())) - .expect("Error sending response"); - connection - .sender - .send(Message::Response(response.clone())) - .expect("Error sending response"); - // connection - // .sender - // .send(Message::Response(Response { - // id: response.id, - // result: None, - // error: None, - // })) - // .expect("Error sending message"); - } - thread::sleep(std::time::Duration::from_millis(5)); - } -} diff --git a/test.json b/test.json new file mode 100644 index 0000000..58ac9e8 --- /dev/null +++ b/test.json @@ -0,0 +1,18 @@ +{ + "macos": { + "model_gguf": { + "repository": "deepseek-coder-6.7b-base", + "name": "Q4_K_M.gguf", + "fim": false, + "n_ctx": 2048, + "n_threads": 8, + "n_gpu_layers": 35 + } + }, + "linux": { + "model_gptq": { + "repository": "theblokesomething", + "name": "some q5 or something" + } + } +}
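
For reference, below is a minimal standalone sketch of the prompt-window clamping performed by the non-FIM branch of FileStore::build_prompt in src/memory_backends/file_store.rs: the prompt is the slice of the document that ends at the cursor and reaches back at most the configured maximum context length, saturating at the start of the file. The helper name clamp_prompt, the example text, and the literal window size are illustrative only; in the server the limit comes from Configuration::get_maximum_context_length() and the cursor comes from the LSP TextDocumentPositionParams. The sketch assumes only the ropey crate, which the patch already depends on.

use ropey::Rope;

// Mirrors the non-FIM branch of FileStore::build_prompt: keep at most
// `max_context` characters of the document, ending at the cursor.
fn clamp_prompt(rope: &Rope, line: usize, character: usize, max_context: usize) -> String {
    // Character index of the cursor inside the rope.
    let cursor_index = rope.line_to_char(line) + character;
    // Walk back at most `max_context` characters, stopping at the start of the file.
    let start = cursor_index.checked_sub(max_context).unwrap_or(0);
    rope.slice(start..cursor_index).to_string()
}

fn main() {
    let rope = Rope::from_str("0123456789\nabcdef\n");
    // Cursor at line 1, character 3, with a 6-character window:
    // the window reaches back across the newline into line 0.
    let prompt = clamp_prompt(&rope, 1, 3, 6);
    assert_eq!(prompt, "89\nabc");
    println!("{prompt:?}");
}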