diff --git a/Cargo.lock b/Cargo.lock
index d442fc5..3246022 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1409,8 +1409,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
 
 [[package]]
 name = "llama-cpp-2"
-version = "0.1.34"
-source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
+version = "0.1.47"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f391a790923a78bbe6100824124492c7df3d17b26340424eb813b88a521707a3"
 dependencies = [
  "llama-cpp-sys-2",
  "thiserror",
@@ -1419,8 +1420,9 @@ dependencies = [
 
 [[package]]
 name = "llama-cpp-sys-2"
-version = "0.1.34"
-source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
+version = "0.1.47"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1f26aac755fc36d5cc19f0853c4d359db8c4e4e5944705a27c7cce5e2cb9c36"
 dependencies = [
  "bindgen",
  "cc",
diff --git a/Cargo.toml b/Cargo.toml
index 0756298..462777b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,8 +19,7 @@ tokenizers = "0.14.1"
 parking_lot = "0.12.1"
 once_cell = "1.19.0"
 directories = "5.0.1"
-# llama-cpp-2 = "0.1.31"
-llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" }
+llama-cpp-2 = "0.1.47"
 minijinja = { version = "1.0.12", features = ["loader"] }
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 tracing = "0.1.40"
diff --git a/src/configuration.rs b/src/configuration.rs
index 4148078..90c3b4a 100644
--- a/src/configuration.rs
+++ b/src/configuration.rs
@@ -97,6 +97,10 @@ pub struct Model {
     pub name: Option,
 }
 
+const fn llamacpp_max_requests_per_second_default() -> f32 {
+    0.25
+}
+
 #[derive(Clone, Debug, Deserialize)]
 pub struct LLaMACPP {
     // The model to use
@@ -109,6 +113,9 @@ pub struct LLaMACPP {
     pub max_tokens: MaxTokens,
     // Chat args
     pub chat: Option,
+    // The maximum requests per second
+    #[serde(default = "llamacpp_max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
     // Kwargs passed to LlamaCPP
     #[serde(flatten)]
     pub kwargs: Kwargs,
@@ -128,28 +135,33 @@ impl Default for LLaMACPP {
             }),
             max_tokens: MaxTokens::default(),
             chat: None,
+            max_requests_per_second: f32::MAX,
             kwargs: Kwargs::default(),
         }
     }
 }
 
+const fn api_max_requests_per_second_default() -> f32 {
+    0.5
+}
+
 const fn openai_top_p_default() -> f32 {
     0.95
 }
 
-const fn openai_presence_penalty() -> f32 {
+const fn openai_presence_penalty_default() -> f32 {
     0.
 }
 
-const fn openai_frequency_penalty() -> f32 {
+const fn openai_frequency_penalty_default() -> f32 {
     0.
 }
 
-const fn openai_temperature() -> f32 {
+const fn openai_temperature_default() -> f32 {
     0.1
 }
 
-const fn openai_max_context() -> usize {
+const fn openai_max_context_default() -> usize {
     DEFAULT_OPENAI_MAX_CONTEXT
 }
 
@@ -162,6 +174,9 @@ pub struct OpenAI {
     pub completions_endpoint: Option,
     // The chat endpoint
     pub chat_endpoint: Option,
+    // The maximum requests per second
+    #[serde(default = "api_max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
     // The model name
     pub model: String,
     // Fill in the middle support
@@ -174,13 +189,13 @@
     // Other available args
     #[serde(default = "openai_top_p_default")]
     pub top_p: f32,
-    #[serde(default = "openai_presence_penalty")]
+    #[serde(default = "openai_presence_penalty_default")]
     pub presence_penalty: f32,
-    #[serde(default = "openai_frequency_penalty")]
+    #[serde(default = "openai_frequency_penalty_default")]
     pub frequency_penalty: f32,
-    #[serde(default = "openai_temperature")]
+    #[serde(default = "openai_temperature_default")]
     pub temperature: f32,
-    #[serde(default = "openai_max_context")]
+    #[serde(default = "openai_max_context_default")]
     max_context: usize,
 }
 
@@ -193,6 +208,9 @@ pub struct Anthropic {
     pub completions_endpoint: Option,
     // The chat endpoint
     pub chat_endpoint: Option,
+    // The maximum requests per second
+    #[serde(default = "api_max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
     // The model name
     pub model: String,
     // The maximum number of new tokens to generate
@@ -203,15 +221,17 @@
     // System prompt
     #[serde(default = "openai_top_p_default")]
     pub top_p: f32,
-    #[serde(default = "openai_temperature")]
+    #[serde(default = "openai_temperature_default")]
     pub temperature: f32,
-    #[serde(default = "openai_max_context")]
+    #[serde(default = "openai_max_context_default")]
     max_context: usize,
 }
 
 #[derive(Clone, Debug, Deserialize, Default)]
 pub struct ValidConfiguration {
+    #[serde(default)]
     pub memory: ValidMemoryBackend,
+    #[serde(default)]
     pub transformer: ValidTransformerBackend,
 }
 
@@ -246,9 +266,17 @@ impl Configuration {
     }
 
     ///////////////////////////////////////
-    // Helpers for the Memory Backend /////
+    // Helpers for the backends ///////////
     ///////////////////////////////////////
 
+    pub fn get_transformer_max_requests_per_second(&self) -> f32 {
+        match &self.config.transformer {
+            ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp.max_requests_per_second,
+            ValidTransformerBackend::OpenAI(openai) => openai.max_requests_per_second,
+            ValidTransformerBackend::Anthropic(anthropic) => anthropic.max_requests_per_second,
+        }
+    }
+
     pub fn get_max_context_length(&self) -> usize {
         match &self.config.transformer {
             ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp
diff --git a/src/main.rs b/src/main.rs
index 291788c..a6f9427 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,7 +5,6 @@ use lsp_types::{
     request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
     RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
 };
-use parking_lot::Mutex;
 use std::{
     sync::{mpsc, Arc},
     thread,
@@ -75,33 +74,35 @@ fn main() -> Result<()> {
 
 fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     // Build our configuration
-    let configuration = Configuration::new(args)?;
+    let config = Configuration::new(args)?;
 
     // Wrap the connection for sharing between threads
     let connection = Arc::new(connection);
 
     // Our channel we use to communicate with our transformer worker
-    let last_worker_request = Arc::new(Mutex::new(None));
+    // let last_worker_request = Arc::new(Mutex::new(None));
+    let (transformer_tx, transformer_rx) = mpsc::channel();
 
     // The channel we use to communicate with our memory worker
     let (memory_tx, memory_rx) = mpsc::channel();
 
     // Setup the transformer worker
-    let memory_backend: Box = configuration.clone().try_into()?;
+    let memory_backend: Box = config.clone().try_into()?;
     thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
 
     // Setup our transformer worker
     let transformer_backend: Box =
-        configuration.clone().try_into()?;
-    let thread_last_worker_request = last_worker_request.clone();
+        config.clone().try_into()?;
     let thread_connection = connection.clone();
     let thread_memory_tx = memory_tx.clone();
+    let thread_config = config.clone();
     thread::spawn(move || {
         transformer_worker::run(
             transformer_backend,
             thread_memory_tx,
-            thread_last_worker_request,
+            transformer_rx,
             thread_connection,
+            thread_config,
         )
     });
 
@@ -111,33 +112,28 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
                 if connection.handle_shutdown(&req)? {
                     return Ok(());
                 }
-                // Right now each if / else basically does the same thing,
-                // but this may change soon so it is worth making it a little
-                // more verbose than it needs to be now
                 if request_is::(&req) {
                     match cast::(req) {
                         Ok((id, params)) => {
-                            let mut lcr = last_worker_request.lock();
                             let completion_request = CompletionRequest::new(id, params);
-                            *lcr = Some(WorkerRequest::Completion(completion_request));
+                            transformer_tx.send(WorkerRequest::Completion(completion_request))?;
                         }
                         Err(err) => error!("{err:?}"),
                     }
                 } else if request_is::(&req) {
                     match cast::(req) {
                         Ok((id, params)) => {
-                            let mut lcr = last_worker_request.lock();
-                            let completion_request = GenerateRequest::new(id, params);
-                            *lcr = Some(WorkerRequest::Generate(completion_request));
+                            let generate_request = GenerateRequest::new(id, params);
+                            transformer_tx.send(WorkerRequest::Generate(generate_request))?;
                         }
                         Err(err) => error!("{err:?}"),
                     }
                 } else if request_is::(&req) {
                     match cast::(req) {
                         Ok((id, params)) => {
-                            let mut lcr = last_worker_request.lock();
-                            let completion_request = GenerateStreamRequest::new(id, params);
-                            *lcr = Some(WorkerRequest::GenerateStream(completion_request));
+                            let generate_stream_request = GenerateStreamRequest::new(id, params);
+                            transformer_tx
+                                .send(WorkerRequest::GenerateStream(generate_stream_request))?;
                         }
                         Err(err) => error!("{err:?}"),
                     }
diff --git a/src/transformer_worker.rs b/src/transformer_worker.rs
index a738ed5..7b68965 100644
--- a/src/transformer_worker.rs
+++ b/src/transformer_worker.rs
@@ -3,11 +3,12 @@ use lsp_types::{
     CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
     Position, Range, TextEdit,
 };
-use parking_lot::Mutex;
-use std::{sync::Arc, thread};
+use std::sync::Arc;
+use std::time::SystemTime;
 use tokio::sync::oneshot;
-use tracing::error;
+use tracing::{debug, error, instrument};
 
+use crate::configuration::Configuration;
 use crate::custom_requests::generate::{GenerateParams, GenerateResult};
 use crate::custom_requests::generate_stream::GenerateStreamParams;
 use crate::memory_backends::PromptForType;
@@ -53,7 +54,7 @@ impl GenerateStreamRequest {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub enum WorkerRequest {
     Completion(CompletionRequest),
     Generate(GenerateRequest),
@@ -72,6 +73,7 @@ pub struct DoGenerateStreamResponse {
     pub generated_text: String,
 }
 
+#[instrument(skip(transformer_backend, memory_backend_tx, connection))]
 async fn do_task(
     transformer_backend: Arc>,
     memory_backend_tx: std::sync::mpsc::Sender,
     request: WorkerRequest,
     connection: Arc,
@@ -100,21 +102,22 @@
             }
         }
         WorkerRequest::GenerateStream(_) => {
-            panic!("Streaming is not supported yet")
+            panic!("Streaming is not yet supported")
         }
     };
     connection
         .sender
         .send(Message::Response(response))
-        .expect("Error sending message");
+        .expect("Error sending response");
     Ok(())
 }
 
 fn do_run(
     transformer_backend: Box,
     memory_backend_tx: std::sync::mpsc::Sender,
-    last_worker_request: Arc>>,
+    transformer_rx: std::sync::mpsc::Receiver,
     connection: Arc,
+    config: Configuration,
 ) -> anyhow::Result<()> {
     let transformer_backend = Arc::new(transformer_backend);
     let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -122,43 +125,55 @@
         .enable_all()
         .build()?;
 
+    // This logic is not perfect, but works well enough for now
+    let max_requests_per_second = config.get_transformer_max_requests_per_second();
+    let mut first_request = SystemTime::now();
+    let mut requests_in_last_5_seconds = 0.;
+
     loop {
-        let option_worker_request: Option = {
-            let mut completion_request = last_worker_request.lock();
-            std::mem::take(&mut *completion_request)
-        };
-        if let Some(request) = option_worker_request {
-            let thread_transformer_backend = transformer_backend.clone();
-            let thread_memory_backend_tx = memory_backend_tx.clone();
-            let thread_connection = connection.clone();
-            runtime.spawn(async move {
-                if let Err(e) = do_task(
-                    thread_transformer_backend,
-                    thread_memory_backend_tx,
-                    request,
-                    thread_connection,
-                )
-                .await
-                {
-                    error!("error in transformer worker task: {e}")
-                }
-            });
+        let request = transformer_rx.recv()?;
+
+        if first_request.elapsed()?.as_secs() > 5 {
+            first_request = SystemTime::now();
+            requests_in_last_5_seconds = 0.;
         }
-        thread::sleep(std::time::Duration::from_millis(5));
+        if requests_in_last_5_seconds / 5. > max_requests_per_second {
+            debug!("rate limiting transform request");
+            continue;
+        }
+        requests_in_last_5_seconds += 1.;
+
+        let thread_transformer_backend = transformer_backend.clone();
+        let thread_memory_backend_tx = memory_backend_tx.clone();
+        let thread_connection = connection.clone();
+        runtime.spawn(async move {
+            if let Err(e) = do_task(
+                thread_transformer_backend,
+                thread_memory_backend_tx,
+                request,
+                thread_connection,
+            )
+            .await
+            {
+                error!("transformer worker task: {e}")
+            }
+        });
     }
 }
 
 pub fn run(
     transformer_backend: Box,
-    memory_backend_tx: std::sync::mpsc::Sender,
-    last_worker_request: Arc>>,
+    memory_tx: std::sync::mpsc::Sender,
+    transformer_rx: std::sync::mpsc::Receiver,
     connection: Arc,
+    config: Configuration,
 ) {
     if let Err(e) = do_run(
         transformer_backend,
-        memory_backend_tx,
-        last_worker_request,
+        memory_tx,
+        transformer_rx,
         connection,
+        config,
     ) {
         error!("error in transformer worker: {e}")
     }
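
Note: the rate limiting added to do_run above is intentionally coarse. It counts how many requests arrive in a rolling 5-second window and silently drops any request that would push the average above max_requests_per_second. A minimal standalone sketch of that same idea follows; the RateLimiter type is illustrative only and is not part of this patch:

    use std::time::SystemTime;

    /// Mirrors the rolling 5-second window logic added to do_run.
    struct RateLimiter {
        window_start: SystemTime,
        requests_in_window: f32,
        max_requests_per_second: f32,
    }

    impl RateLimiter {
        fn new(max_requests_per_second: f32) -> Self {
            Self {
                window_start: SystemTime::now(),
                requests_in_window: 0.,
                max_requests_per_second,
            }
        }

        /// Returns true if the request may proceed, false if it should be dropped.
        fn allow(&mut self) -> bool {
            // Start a fresh window once more than 5 seconds have elapsed.
            if self.window_start.elapsed().map(|d| d.as_secs() > 5).unwrap_or(true) {
                self.window_start = SystemTime::now();
                self.requests_in_window = 0.;
            }
            // Drop the request if the average rate over the window is already too high.
            if self.requests_in_window / 5. > self.max_requests_per_second {
                return false;
            }
            self.requests_in_window += 1.;
            true
        }
    }

    fn main() {
        // With the llama.cpp default of 0.25 requests per second, only the first
        // couple of requests in each 5-second window get through.
        let mut limiter = RateLimiter::new(0.25);
        for i in 0..5 {
            println!("request {i}: allowed = {}", limiter.allow());
        }
    }

With the defaults introduced here (0.25 requests per second for llama.cpp, 0.5 for the OpenAI and Anthropic backends), only a few requests per window actually reach the backend; the rest are skipped.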