From ac09f89da7ee91df9cb6993960e5bcc58fcd4338 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 24 Nov 2023 15:10:45 -0800
Subject: [PATCH] Moving towards trait system for models and cleaned up a ton of stuff

---
 Cargo.lock                                  |  79 ++++++-
 Cargo.toml                                  |  12 +-
 run.sh                                      |   2 +-
 src/main.rs                                 | 216 +++++++++++---------
 src/models/mod.rs                           |  32 +++
 src/{transformer.rs => models/starcoder.rs} |  66 +++---
 6 files changed, 268 insertions(+), 139 deletions(-)
 create mode 100644 src/models/mod.rs
 rename src/{transformer.rs => models/starcoder.rs} (83%)

diff --git a/Cargo.lock b/Cargo.lock
index ffef18d..5273441 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3,10 +3,13 @@ version = 3

 [[package]]
-name = "accelerate-src"
-version = "0.3.2"
+name = "addr2line"
+version = "0.21.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d"
+checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
+dependencies = [
+ "gimli",
+]

 [[package]]
 name = "adler"
@@ -76,6 +79,9 @@
 name = "anyhow"
 version = "1.0.75"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
+dependencies = [
+ "backtrace",
+]

 [[package]]
 name = "autocfg"
@@ -83,6 +89,21 @@
 version = "1.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"

+[[package]]
+name = "backtrace"
+version = "0.3.69"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837"
+dependencies = [
+ "addr2line",
+ "cc",
+ "cfg-if",
+ "libc",
+ "miniz_oxide",
+ "object",
+ "rustc-demangle",
+]
+
 [[package]]
 name = "base64"
 version = "0.13.1"
@@ -137,11 +158,11 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
 name = "candle-core"
 version = "0.3.1"
 dependencies = [
- "accelerate-src",
 "byteorder",
+ "candle-kernels",
+ "cudarc",
 "gemm",
 "half",
- "libc",
 "memmap2",
 "num-traits",
 "num_cpus",
@@ -154,11 +175,19 @@ dependencies = [
 "zip",
 ]

+[[package]]
+name = "candle-kernels"
+version = "0.3.1"
+dependencies = [
+ "anyhow",
+ "glob",
+ "rayon",
+]
+
 [[package]]
 name = "candle-nn"
 version = "0.3.1"
 dependencies = [
- "accelerate-src",
 "candle-core",
 "half",
 "num-traits",
@@ -172,7 +201,6 @@ dependencies = [
 name = "candle-transformers"
 version = "0.3.1"
 dependencies = [
- "accelerate-src",
 "byteorder",
 "candle-core",
 "candle-nn",
@@ -334,6 +362,15 @@
 version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"

+[[package]]
+name = "cudarc"
+version = "0.9.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f1871a911a2b9a3f66a285896a719159985683bf9903aa2cf89e0c9f53e14552"
+dependencies = [
+ "half",
+]
+
 [[package]]
 name = "darling"
 version = "0.14.4"
@@ -636,6 +673,18 @@ dependencies = [
 "wasi",
 ]

+[[package]]
+name = "gimli"
+version = "0.28.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
+
+[[package]]
+name = "glob"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
+
 [[package]]
 name = "half"
version = "2.3.1" @@ -793,7 +842,6 @@ dependencies = [ "hf-hub", "lsp-server", "lsp-types", - "once_cell", "parking_lot", "rand", "ropey", @@ -968,6 +1016,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "object" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.18.0" @@ -1306,6 +1363,12 @@ dependencies = [ "str_indices", ] +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + [[package]] name = "rustix" version = "0.38.25" diff --git a/Cargo.toml b/Cargo.toml index 9e159b8..377eb31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,18 +9,20 @@ edition = "2021" anyhow = "1.0.75" lsp-server = "0.7.4" lsp-types = "0.94.1" -once_cell = "1.18.0" -parking_lot = "0.12.1" ropey = "1.6.1" serde = "1.0.190" serde_json = "1.0.108" # candle-core = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] } # candle-nn = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] } # candle-transformers = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] } -candle-core = { path = "../candle/candle-core", version = "0.3.1", features = ["accelerate"] } -candle-nn = { path = "../candle/candle-nn", version = "0.3.1", features = ["accelerate"] } -candle-transformers = { path = "../candle/candle-transformers", version = "0.3.1", features = ["accelerate"] } +candle-core = { path = "../candle/candle-core" } +candle-nn = { path = "../candle/candle-nn" } +candle-transformers = { path = "../candle/candle-transformers" } hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" } rand = "0.8.5" tokenizers = "0.14.1" +parking_lot = "0.12.1" +[features] +default = [] +cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] diff --git a/run.sh b/run.sh index 5b27604..82f9831 100755 --- a/run.sh +++ b/run.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -/Users/silas/Projects/lsp-ai/target/release/lsp-ai +/home/silas/Projects/lsp-ai/target/release/lsp-ai diff --git a/src/main.rs b/src/main.rs index 58116d5..9bb8002 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,26 +4,32 @@ use core::panic; use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId, Response}; use lsp_types::{ request::Completion, CompletionItem, CompletionItemKind, CompletionList, CompletionOptions, - CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams, Position, Range, - RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit, + CompletionParams, CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams, + Position, Range, RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit, }; use parking_lot::Mutex; -use serde::Deserialize; -// use pyo3::prelude::*; -// use pyo3::types::PyTuple; use ropey::Rope; +use serde::Deserialize; use std::collections::HashMap; +use std::sync::Arc; +use std::thread; -mod transformer; - -static FILE_MAP: once_cell::sync::Lazy>> = - once_cell::sync::Lazy::new(|| Mutex::new(HashMap::new())); +mod models; +use 
 
 // Taken directly from: https://github.com/rust-lang/rust-analyzer
 fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
     notification.method == N::METHOD
 }
 
+fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
+where
+    R: lsp_types::request::Request,
+    R::Params: serde::de::DeserializeOwned,
+{
+    req.extract(R::METHOD)
+}
+
 fn main() -> Result<()> {
     let (connection, io_threads) = Connection::stdio();
     let server_capabilities = serde_json::to_value(&ServerCapabilities {
@@ -41,15 +47,107 @@ fn main() -> Result<()> {
 
 #[derive(Deserialize)]
 struct Params {
-    model: Option<String>,
-    model_file: Option<String>,
-    model_type: Option<String>,
-    device: Option<String>,
+    // We may want to put other non-model-related parameters here in the future
+    model_params: Option<ModelParams>,
 }
 
+struct CompletionRequest {
+    id: RequestId,
+    params: CompletionParams,
+    rope: Rope,
+}
+
+impl CompletionRequest {
+    fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self {
+        Self { id, params, rope }
+    }
+}
+
+// This main loop is tricky
+// We spawn a worker thread that does the heavy lifting because we do not want to process
+// every completion request we receive: a request may take a few seconds depending on the
+// model configuration and the available hardware, so we only process the latest one
 fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
     let params: Params = serde_json::from_value(params)?;
-    let mut text_generation = transformer::build()?;
+
+    // Prep variables
+    let connection = Arc::new(connection);
+    let mut file_map: HashMap<String, Rope> = HashMap::new();
+
+    // How we communicate between the worker and receiver threads
+    let last_completion_request = Arc::new(Mutex::new(None));
+
+    // Handles cloned for the worker thread
+    let thread_last_completion_request = last_completion_request.clone();
+    let thread_connection = connection.clone();
+    // We need to allow unreachable code to be able to use the question mark operator here
+    // We could probably restructure this to not require it
+    #[allow(unreachable_code)]
+    thread::spawn(move || {
+        // Build the model from the params
+        let mut model: Box<dyn Model> = params.model_params.unwrap_or_default().try_into()?;
+        loop {
+            // Take the pending request and drop the guard immediately so the lock is not
+            // held while the model runs; otherwise the receiver thread would block on it
+            let mut completion_request = thread_last_completion_request.lock();
+            let params = std::mem::take(&mut *completion_request);
+            drop(completion_request);
+            if let Some(CompletionRequest {
+                id,
+                params,
+                mut rope,
+            }) = params
+            {
+                let filter_text = rope
+                    .get_line(params.text_document_position.position.line as usize)
+                    .context("Error getting line with ropey")?
+                    .to_string();
+
+                // Convert rope to correct prompt for llm
+                let start_index = rope
+                    .line_to_char(params.text_document_position.position.line as usize)
+                    + params.text_document_position.position.character as usize;
+                rope.insert(start_index, "");
+                let prompt = format!("{}", rope);
+                let insert_text = model.run(&prompt)?;
+
+                // Create and return the completion
+                let completion_text_edit = TextEdit::new(
+                    Range::new(
+                        Position::new(
+                            params.text_document_position.position.line,
+                            params.text_document_position.position.character,
+                        ),
+                        Position::new(
+                            params.text_document_position.position.line,
+                            params.text_document_position.position.character,
+                        ),
+                    ),
+                    insert_text.clone(),
+                );
+                let item = CompletionItem {
+                    label: format!("ai - {insert_text}"),
+                    filter_text: Some(filter_text),
+                    text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
+                    kind: Some(CompletionItemKind::TEXT),
+                    ..Default::default()
+                };
+                let completion_list = CompletionList {
+                    is_incomplete: false,
+                    items: vec![item],
+                };
+                let result = Some(CompletionResponse::List(completion_list));
+                let result = serde_json::to_value(&result).unwrap();
+                let resp = Response {
+                    id,
+                    result: Some(result),
+                    error: None,
+                };
+                thread_connection.sender.send(Message::Response(resp))?;
+            }
+            thread::sleep(std::time::Duration::from_millis(5));
+        }
+        anyhow::Ok(())
+    });
+
     for msg in &connection.receiver {
         match msg {
             Message::Request(req) => {
@@ -59,59 +157,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
                 match cast::<Completion>(req) {
                     Ok((id, params)) => {
                         // Get rope
-                        let file_map = FILE_MAP.lock();
-                        let mut rope = file_map
+                        let rope = file_map
                             .get(params.text_document_position.text_document.uri.as_str())
                             .context("Error file not found")?
                             .clone();
-                        let filter_text = rope
-                            .get_line(params.text_document_position.position.line as usize)
-                            .context("Error getting line with ropey")?
-                            .to_string();
-
-                        // Convert rope to correct prompt for llm
-                        let start_index = rope
-                            .line_to_char(params.text_document_position.position.line as usize)
-                            + params.text_document_position.position.character as usize;
-                        rope.insert(start_index, "");
-                        let prompt = format!("{}", rope);
-                        let insert_text = text_generation.run(&prompt, 64)?;
-
-                        // Create and return the completion
-                        let completion_text_edit = TextEdit::new(
-                            Range::new(
-                                Position::new(
-                                    params.text_document_position.position.line,
-                                    params.text_document_position.position.character,
-                                ),
-                                Position::new(
-                                    params.text_document_position.position.line,
-                                    params.text_document_position.position.character,
-                                ),
-                            ),
-                            insert_text.clone(),
-                        );
-                        let item = CompletionItem {
-                            label: format!("ai - {insert_text}"),
-                            filter_text: Some(filter_text),
-                            text_edit: Some(lsp_types::CompletionTextEdit::Edit(
-                                completion_text_edit,
-                            )),
-                            kind: Some(CompletionItemKind::TEXT),
-                            ..Default::default()
-                        };
-                        let completion_list = CompletionList {
-                            is_incomplete: false,
-                            items: vec![item],
-                        };
-                        let result = Some(CompletionResponse::List(completion_list));
-                        let result = serde_json::to_value(&result).unwrap();
-                        let resp = Response {
-                            id,
-                            result: Some(result),
-                            error: None,
-                        };
-                        connection.sender.send(Message::Response(resp))?;
+                        // Update the last CompletionRequest
+                        let mut lcr = last_completion_request.lock();
+                        *lcr = Some(CompletionRequest::new(id, params, rope));
                         continue;
                     }
                     Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
@@ -123,11 +175,9 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
                 if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
                     let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
                     let rope = Rope::from_str(&params.text_document.text);
-                    let mut file_map = FILE_MAP.lock();
                     file_map.insert(params.text_document.uri.to_string(), rope);
                 } else if notification_is::<lsp_types::notification::DidChangeTextDocument>(&not) {
                     let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
-                    let mut file_map = FILE_MAP.lock();
                     let rope = file_map
                         .get_mut(params.text_document.uri.as_str())
                         .context("Error trying to get file that does not exist")?;
@@ -146,7 +196,6 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
                     }
                 } else if notification_is::<lsp_types::notification::DidRenameFiles>(&not) {
                     let params: RenameFilesParams = serde_json::from_value(not.params)?;
-                    let mut file_map = FILE_MAP.lock();
                     for file_rename in params.files {
                         if let Some(rope) = file_map.remove(&file_rename.old_uri) {
                             file_map.insert(file_rename.new_uri, rope);
@@ -159,36 +208,3 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
         }
     }
     Ok(())
 }
-
-fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
-where
-    R: lsp_types::request::Request,
-    R::Params: serde::de::DeserializeOwned,
-{
-    req.extract(R::METHOD)
-}
-
-// #[cfg(test)]
-// mod tests {
-//     use super::*;
-
-//     #[test]
-//     fn test_lsp() -> Result<()> {
-//         let prompt = "def sum_two_numers(x: int, y:";
-//         let result = Python::with_gil(|py| -> Result<String> {
-//             let transform: Py<PyAny> = PY_MODULE
-//                 .as_ref()
-//                 .expect("Error getting python module")
-//                 .getattr(py, "transform")
-//                 .expect("Error getting transform");
-
-//             let output = transform
-//                 .call1(py, PyTuple::new(py, &[prompt]))
-//                 .expect("Error calling transform");
-
-//             Ok(output.extract(py).expect("Error extracting result"))
-//         })?;
-//         println!("\n\nTHE RESULT\n{:?}\n\n", result);
-//         Ok(())
-//     }
-// }
diff --git a/src/models/mod.rs b/src/models/mod.rs
new file mode 100644
index 0000000..5a691bf
--- /dev/null
+++ b/src/models/mod.rs
@@ -0,0 +1,32 @@
+use anyhow::Result;
+use serde::Deserialize;
+
+mod starcoder;
+
+pub trait Model {
+    fn run(&mut self, prompt: &str) -> Result<String>;
+}
+
+#[derive(Deserialize, Default)]
+pub struct ModelParams {
+    model: Option<String>,
+    model_file: Option<String>,
+    model_type: Option<String>,
+    max_length: Option<usize>,
+}
+
+impl TryFrom<ModelParams> for Box<dyn Model> {
+    type Error = anyhow::Error;
+
+    fn try_from(value: ModelParams) -> Result<Self> {
+        let model_type = value.model_type.unwrap_or("starcoder".to_string());
+        let max_length = value.max_length.unwrap_or(12);
+        Ok(Box::new(match model_type.as_str() {
+            "starcoder" => starcoder::build_model(value.model, value.model_file, max_length)?,
+            _ => anyhow::bail!(
+                "Model type: {} not supported. Feel free to make a PR or create a GitHub issue.",
+                model_type
+            ),
+        }))
+    }
+}
diff --git a/src/transformer.rs b/src/models/starcoder.rs
similarity index 83%
rename from src/transformer.rs
rename to src/models/starcoder.rs
index 79e4a30..2cc79c6 100644
--- a/src/transformer.rs
+++ b/src/models/starcoder.rs
@@ -6,32 +6,16 @@ use candle_transformers::models::bigcode::{Config, GPTBigCode};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use tokenizers::Tokenizer;
 
-pub struct TextGeneration {
+pub struct Model {
     model: GPTBigCode,
     device: Device,
     tokenizer: Tokenizer,
     logits_processor: LogitsProcessor,
+    max_length: usize,
 }
 
-impl TextGeneration {
-    fn new(
-        model: GPTBigCode,
-        tokenizer: Tokenizer,
-        seed: u64,
-        temp: Option<f64>,
-        top_p: Option<f64>,
-        device: &Device,
-    ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
-        Self {
-            model,
-            tokenizer,
-            logits_processor,
-            device: device.clone(),
-        }
-    }
-
-    pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<String> {
+impl super::Model for Model {
+    fn run(&mut self, prompt: &str) -> Result<String> {
         eprintln!("Starting to generate tokens");
         let mut tokens = self
             .tokenizer
@@ -42,7 +26,7 @@
         let mut new_tokens = vec![];
         let mut outputs = vec![];
         let start_gen = std::time::Instant::now();
-        for index in 0..sample_len {
+        for index in 0..self.max_length {
             let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
                 (1, tokens.len().saturating_sub(1))
             } else {
@@ -62,15 +46,40 @@
         let dt = start_gen.elapsed();
         self.model.clear_cache();
         eprintln!(
-            "GENERATED {} tokens in {} seconds",
+            "GENERATED {} tokens in {} milliseconds",
             outputs.len(),
-            dt.as_secs()
+            dt.as_millis()
         );
         Ok(outputs.join(""))
     }
 }
 
-pub fn build() -> Result<TextGeneration> {
+impl Model {
+    fn new(
+        model: GPTBigCode,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        device: &Device,
+        max_length: usize,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            tokenizer,
+            logits_processor,
+            device: device.clone(),
+            max_length,
+        }
+    }
+}
+
+pub fn build_model(
+    model: Option<String>,
+    model_file: Option<String>,
+    max_length: usize,
+) -> Result<Model> {
     let start = std::time::Instant::now();
     eprintln!("Loading in model");
     let api = ApiBuilder::new()
@@ -87,17 +96,24 @@ pub fn build() -> Result<TextGeneration> {
         .map(|f| repo.get(f))
         .collect::<Result<Vec<_>, _>>()?;
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    // Set the device
+    #[cfg(feature = "cuda")]
+    let device = Device::new_cuda(0)?;
+    #[cfg(not(feature = "cuda"))]
     let device = Device::Cpu;
+
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
     let config = Config::starcoder_1b();
     let model = GPTBigCode::load(vb, config)?;
     eprintln!("loaded the model in {:?}", start.elapsed());
-    Ok(TextGeneration::new(
+    Ok(Model::new(
         model,
         tokenizer,
         0,
         Some(0.85),
         None,
         &device,
+        max_length,
     ))
 }
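
Note on the worker/receiver hand-off introduced in src/main.rs: the patch keeps a
single Arc<Mutex<Option<CompletionRequest>>> slot that the receiver thread
overwrites and the worker thread drains, so stale completion requests are dropped
rather than queued. The following is a minimal, self-contained sketch of that
pattern; it assumes std::sync::Mutex in place of parking_lot and a plain String in
place of CompletionRequest, and all names here are illustrative, not part of the
patch.

    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    fn main() {
        // Shared slot: the producer overwrites it, the worker drains it.
        let latest: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));

        let worker_slot = Arc::clone(&latest);
        let worker = thread::spawn(move || loop {
            // Take the pending request and release the lock in one statement,
            // so the producer is never blocked while the slow work runs.
            let request = worker_slot.lock().unwrap().take();
            match request {
                Some(req) if req == "quit" => break,
                Some(req) => println!("processing: {req}"),
                None => thread::sleep(Duration::from_millis(5)),
            }
        });

        // Producer: rapid-fire requests. Intermediate ones may be overwritten
        // before the worker ever sees them, which is the point: only the most
        // recent completion request is still worth answering.
        for i in 0..100 {
            *latest.lock().unwrap() = Some(format!("completion #{i}"));
        }
        *latest.lock().unwrap() = Some("quit".into());
        worker.join().unwrap();
    }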
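
For reference, the new Params/ModelParams structs define the JSON shape a client
would pass as LSP initializationOptions. Below is a sketch of that round trip,
assuming serde (with derive) and serde_json as dependencies; the structs are
redeclared locally so the example stands alone, and the model id shown is a
placeholder, not something the patch pins down.

    use serde::Deserialize;

    // Local mirrors of the structs from src/main.rs and src/models/mod.rs.
    #[derive(Deserialize, Debug, Default)]
    struct ModelParams {
        model: Option<String>,
        model_file: Option<String>,
        model_type: Option<String>,
        max_length: Option<usize>,
    }

    #[derive(Deserialize, Debug)]
    struct Params {
        model_params: Option<ModelParams>,
    }

    fn main() -> Result<(), serde_json::Error> {
        // What an editor might send as initializationOptions.
        let raw = r#"{
            "model_params": {
                "model": "bigcode/starcoderbase-1b",
                "model_type": "starcoder",
                "max_length": 32
            }
        }"#;
        let params: Params = serde_json::from_str(raw)?;
        let model_params = params.model_params.unwrap_or_default();
        // The same defaults the TryFrom impl applies: "starcoder" and 12.
        println!(
            "model_type={}, max_length={}",
            model_params.model_type.unwrap_or_else(|| "starcoder".into()),
            model_params.max_length.unwrap_or(12)
        );
        Ok(())
    }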
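
The (context_size, past_len) computation in starcoder.rs is the usual KV-cache
windowing for incremental decoding: the first step feeds the whole prompt, and
every later step feeds only the newest token, because the cache already holds the
attention state for the rest. A candle-free sketch of just that indexing logic;
the window function and its test values are illustrative, not taken from the
patch.

    // Mirrors the (context_size, past_len) selection in run():
    // step 0 feeds the full prompt; later steps feed only the last token.
    fn window(tokens: &[u32], index: usize, use_cache: bool) -> (&[u32], usize) {
        let (context_size, past_len) = if use_cache && index > 0 {
            (1, tokens.len().saturating_sub(1))
        } else {
            (tokens.len(), 0)
        };
        (&tokens[tokens.len() - context_size..], past_len)
    }

    fn main() {
        let tokens = vec![10, 11, 12, 13];
        // First step: the whole prompt, no past.
        assert_eq!(window(&tokens, 0, true), (&tokens[..], 0));
        // Later steps: just the newest token; the past covers the rest.
        assert_eq!(window(&tokens, 1, true), (&tokens[3..], 3));
        // With the cache disabled, everything is fed on every step.
        assert_eq!(window(&tokens, 5, false), (&tokens[..], 0));
        println!("windowing matches the run() loop in starcoder.rs");
    }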