Bump llamacpp-2 version and general configuration cleanups

Author: SilasMarvin
Date: 2024-04-06 14:43:17 -07:00
parent f921203aa6
commit 25b312fa2a
5 changed files with 107 additions and 67 deletions

Cargo.lock (generated)

@@ -1409,8 +1409,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
 [[package]]
 name = "llama-cpp-2"
-version = "0.1.34"
-source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
+version = "0.1.47"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f391a790923a78bbe6100824124492c7df3d17b26340424eb813b88a521707a3"
 dependencies = [
  "llama-cpp-sys-2",
  "thiserror",
@@ -1419,8 +1420,9 @@ dependencies = [
 [[package]]
 name = "llama-cpp-sys-2"
-version = "0.1.34"
-source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
+version = "0.1.47"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1f26aac755fc36d5cc19f0853c4d359db8c4e4e5944705a27c7cce5e2cb9c36"
 dependencies = [
  "bindgen",
  "cc",


@@ -19,8 +19,7 @@ tokenizers = "0.14.1"
 parking_lot = "0.12.1"
 once_cell = "1.19.0"
 directories = "5.0.1"
-# llama-cpp-2 = "0.1.31"
-llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" }
+llama-cpp-2 = "0.1.47"
 minijinja = { version = "1.0.12", features = ["loader"] }
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 tracing = "0.1.40"


@@ -97,6 +97,10 @@ pub struct Model {
     pub name: Option<String>,
 }
 
+const fn llamacpp_max_requests_per_second_default() -> f32 {
+    0.25
+}
+
 #[derive(Clone, Debug, Deserialize)]
 pub struct LLaMACPP {
     // The model to use
@@ -109,6 +113,9 @@ pub struct LLaMACPP {
     pub max_tokens: MaxTokens,
     // Chat args
     pub chat: Option<Chat>,
+    // The maximum requests per second
+    #[serde(default = "llamacpp_max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
     // Kwargs passed to LlamaCPP
     #[serde(flatten)]
     pub kwargs: Kwargs,
@@ -128,28 +135,33 @@ impl Default for LLaMACPP {
             }),
             max_tokens: MaxTokens::default(),
             chat: None,
+            max_requests_per_second: f32::MAX,
             kwargs: Kwargs::default(),
         }
     }
 }
 
+const fn api_max_requests_per_second_default() -> f32 {
+    0.5
+}
+
 const fn openai_top_p_default() -> f32 {
     0.95
 }
 
-const fn openai_presence_penalty() -> f32 {
+const fn openai_presence_penalty_default() -> f32 {
     0.
 }
 
-const fn openai_frequency_penalty() -> f32 {
+const fn openai_frequency_penalty_default() -> f32 {
     0.
 }
 
-const fn openai_temperature() -> f32 {
+const fn openai_temperature_default() -> f32 {
     0.1
 }
 
-const fn openai_max_context() -> usize {
+const fn openai_max_context_default() -> usize {
     DEFAULT_OPENAI_MAX_CONTEXT
 }
@@ -162,6 +174,9 @@ pub struct OpenAI {
     pub completions_endpoint: Option<String>,
     // The chat endpoint
     pub chat_endpoint: Option<String>,
+    // The maximum requests per second
+    #[serde(default = "api_max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
     // The model name
     pub model: String,
     // Fill in the middle support
@@ -174,13 +189,13 @@ pub struct OpenAI {
     // Other available args
     #[serde(default = "openai_top_p_default")]
     pub top_p: f32,
-    #[serde(default = "openai_presence_penalty")]
+    #[serde(default = "openai_presence_penalty_default")]
     pub presence_penalty: f32,
-    #[serde(default = "openai_frequency_penalty")]
+    #[serde(default = "openai_frequency_penalty_default")]
     pub frequency_penalty: f32,
-    #[serde(default = "openai_temperature")]
+    #[serde(default = "openai_temperature_default")]
     pub temperature: f32,
-    #[serde(default = "openai_max_context")]
+    #[serde(default = "openai_max_context_default")]
     max_context: usize,
 }
@@ -193,6 +208,9 @@ pub struct Anthropic {
     pub completions_endpoint: Option<String>,
     // The chat endpoint
     pub chat_endpoint: Option<String>,
+    // The maximum requests per second
+    #[serde(default = "api_max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
     // The model name
     pub model: String,
     // The maximum number of new tokens to generate
@@ -203,15 +221,17 @@ pub struct Anthropic {
     // System prompt
     #[serde(default = "openai_top_p_default")]
     pub top_p: f32,
-    #[serde(default = "openai_temperature")]
+    #[serde(default = "openai_temperature_default")]
     pub temperature: f32,
-    #[serde(default = "openai_max_context")]
+    #[serde(default = "openai_max_context_default")]
     max_context: usize,
 }
 
 #[derive(Clone, Debug, Deserialize, Default)]
 pub struct ValidConfiguration {
+    #[serde(default)]
     pub memory: ValidMemoryBackend,
+    #[serde(default)]
     pub transformer: ValidTransformerBackend,
 }
@@ -246,9 +266,17 @@ impl Configuration {
     }
 
     ///////////////////////////////////////
-    // Helpers for the Memory Backend /////
+    // Helpers for the backends ///////////
     ///////////////////////////////////////
 
+    pub fn get_transformer_max_requests_per_second(&self) -> f32 {
+        match &self.config.transformer {
+            ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp.max_requests_per_second,
+            ValidTransformerBackend::OpenAI(openai) => openai.max_requests_per_second,
+            ValidTransformerBackend::Anthropic(anthropic) => anthropic.max_requests_per_second,
+        }
+    }
+
     pub fn get_max_context_length(&self) -> usize {
         match &self.config.transformer {
             ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp
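For context on the serde defaults introduced above, the following is a minimal, hypothetical sketch (not the crate's actual types; `LlamaCppSettings` and the simplified `model: String` field are assumptions) showing how `#[serde(default = "...")]` fills in `max_requests_per_second` when a user's configuration omits it.

```rust
// Hypothetical, simplified mirror of the pattern added in this diff.
// The real LLaMACPP struct has more fields; only the default behavior is shown.
use serde::Deserialize;

const fn llamacpp_max_requests_per_second_default() -> f32 {
    0.25
}

#[derive(Debug, Deserialize)]
struct LlamaCppSettings {
    // Stand-in for the real `model` field, which uses a richer type.
    model: String,
    #[serde(default = "llamacpp_max_requests_per_second_default")]
    max_requests_per_second: f32,
}

fn main() -> Result<(), serde_json::Error> {
    // Field omitted: the default of 0.25 requests per second is used.
    let settings: LlamaCppSettings = serde_json::from_str(r#"{ "model": "model.gguf" }"#)?;
    assert_eq!(settings.max_requests_per_second, 0.25);

    // Field present: the user's value wins.
    let settings: LlamaCppSettings =
        serde_json::from_str(r#"{ "model": "model.gguf", "max_requests_per_second": 1.0 }"#)?;
    assert_eq!(settings.max_requests_per_second, 1.0);
    Ok(())
}
```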


@@ -5,7 +5,6 @@ use lsp_types::{
     request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
     RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
 };
-use parking_lot::Mutex;
 use std::{
     sync::{mpsc, Arc},
     thread,
@@ -75,33 +74,35 @@ fn main() -> Result<()> {
 fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     // Build our configuration
-    let configuration = Configuration::new(args)?;
+    let config = Configuration::new(args)?;
 
     // Wrap the connection for sharing between threads
     let connection = Arc::new(connection);
 
     // Our channel we use to communicate with our transformer worker
-    let last_worker_request = Arc::new(Mutex::new(None));
+    // let last_worker_request = Arc::new(Mutex::new(None));
+    let (transformer_tx, transformer_rx) = mpsc::channel();
 
     // The channel we use to communicate with our memory worker
     let (memory_tx, memory_rx) = mpsc::channel();
 
     // Setup the transformer worker
-    let memory_backend: Box<dyn MemoryBackend + Send + Sync> = configuration.clone().try_into()?;
+    let memory_backend: Box<dyn MemoryBackend + Send + Sync> = config.clone().try_into()?;
     thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
 
     // Setup our transformer worker
     let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
-        configuration.clone().try_into()?;
-    let thread_last_worker_request = last_worker_request.clone();
+        config.clone().try_into()?;
     let thread_connection = connection.clone();
     let thread_memory_tx = memory_tx.clone();
+    let thread_config = config.clone();
     thread::spawn(move || {
         transformer_worker::run(
             transformer_backend,
             thread_memory_tx,
-            thread_last_worker_request,
+            transformer_rx,
             thread_connection,
+            thread_config,
         )
     });
@@ -111,33 +112,28 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
             if connection.handle_shutdown(&req)? {
                 return Ok(());
             }
-            // Right now each if / else basically does the same thing,
-            // but this may change soon so it is worth making it a little
-            // more verbose than it needs to be now
             if request_is::<Completion>(&req) {
                 match cast::<Completion>(req) {
                     Ok((id, params)) => {
-                        let mut lcr = last_worker_request.lock();
                         let completion_request = CompletionRequest::new(id, params);
-                        *lcr = Some(WorkerRequest::Completion(completion_request));
+                        transformer_tx.send(WorkerRequest::Completion(completion_request))?;
                     }
                     Err(err) => error!("{err:?}"),
                 }
             } else if request_is::<Generate>(&req) {
                 match cast::<Generate>(req) {
                     Ok((id, params)) => {
-                        let mut lcr = last_worker_request.lock();
-                        let completion_request = GenerateRequest::new(id, params);
-                        *lcr = Some(WorkerRequest::Generate(completion_request));
+                        let generate_request = GenerateRequest::new(id, params);
+                        transformer_tx.send(WorkerRequest::Generate(generate_request))?;
                     }
                     Err(err) => error!("{err:?}"),
                 }
             } else if request_is::<GenerateStream>(&req) {
                 match cast::<GenerateStream>(req) {
                     Ok((id, params)) => {
-                        let mut lcr = last_worker_request.lock();
-                        let completion_request = GenerateStreamRequest::new(id, params);
-                        *lcr = Some(WorkerRequest::GenerateStream(completion_request));
+                        let generate_stream_request = GenerateStreamRequest::new(id, params);
+                        transformer_tx
+                            .send(WorkerRequest::GenerateStream(generate_stream_request))?;
                     }
                     Err(err) => error!("{err:?}"),
                 }


@@ -3,11 +3,12 @@ use lsp_types::{
     CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
     Position, Range, TextEdit,
 };
-use parking_lot::Mutex;
-use std::{sync::Arc, thread};
+use std::sync::Arc;
+use std::time::SystemTime;
 use tokio::sync::oneshot;
-use tracing::error;
+use tracing::{debug, error, instrument};
 
+use crate::configuration::Configuration;
 use crate::custom_requests::generate::{GenerateParams, GenerateResult};
 use crate::custom_requests::generate_stream::GenerateStreamParams;
 use crate::memory_backends::PromptForType;
@@ -53,7 +54,7 @@ impl GenerateStreamRequest {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub enum WorkerRequest {
     Completion(CompletionRequest),
     Generate(GenerateRequest),
@@ -72,6 +73,7 @@ pub struct DoGenerateStreamResponse {
     pub generated_text: String,
 }
 
+#[instrument(skip(transformer_backend, memory_backend_tx, connection))]
 async fn do_task(
     transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
     memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
@@ -100,21 +102,22 @@ async fn do_task(
             }
         }
         WorkerRequest::GenerateStream(_) => {
-            panic!("Streaming is not supported yet")
+            panic!("Streaming is not yet supported")
         }
     };
     connection
         .sender
         .send(Message::Response(response))
-        .expect("Error sending message");
+        .expect("Error sending response");
     Ok(())
 }
 
 fn do_run(
     transformer_backend: Box<dyn TransformerBackend + Send + Sync>,
     memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
-    last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
+    transformer_rx: std::sync::mpsc::Receiver<WorkerRequest>,
     connection: Arc<Connection>,
+    config: Configuration,
 ) -> anyhow::Result<()> {
     let transformer_backend = Arc::new(transformer_backend);
     let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -122,43 +125,55 @@ fn do_run(
         .enable_all()
         .build()?;
 
+    // This logic is not perfect, but works well enough for now
+    let max_requests_per_second = config.get_transformer_max_requests_per_second();
+    let mut first_request = SystemTime::now();
+    let mut requests_in_last_5_seconds = 0.;
+
     loop {
-        let option_worker_request: Option<WorkerRequest> = {
-            let mut completion_request = last_worker_request.lock();
-            std::mem::take(&mut *completion_request)
-        };
-        if let Some(request) = option_worker_request {
-            let thread_transformer_backend = transformer_backend.clone();
-            let thread_memory_backend_tx = memory_backend_tx.clone();
-            let thread_connection = connection.clone();
-            runtime.spawn(async move {
-                if let Err(e) = do_task(
-                    thread_transformer_backend,
-                    thread_memory_backend_tx,
-                    request,
-                    thread_connection,
-                )
-                .await
-                {
-                    error!("error in transformer worker task: {e}")
-                }
-            });
+        let request = transformer_rx.recv()?;
+
+        if first_request.elapsed()?.as_secs() > 5 {
+            first_request = SystemTime::now();
+            requests_in_last_5_seconds = 0.;
         }
-        thread::sleep(std::time::Duration::from_millis(5));
+        if requests_in_last_5_seconds / 5. > max_requests_per_second {
+            debug!("rate limiting transform request");
+            continue;
+        }
+        requests_in_last_5_seconds += 1.;
+
+        let thread_transformer_backend = transformer_backend.clone();
+        let thread_memory_backend_tx = memory_backend_tx.clone();
+        let thread_connection = connection.clone();
+        runtime.spawn(async move {
+            if let Err(e) = do_task(
+                thread_transformer_backend,
+                thread_memory_backend_tx,
+                request,
+                thread_connection,
+            )
+            .await
+            {
+                error!("transformer worker task: {e}")
+            }
+        });
     }
 }
 
 pub fn run(
     transformer_backend: Box<dyn TransformerBackend + Send + Sync>,
-    memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
-    last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
+    memory_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
+    transformer_rx: std::sync::mpsc::Receiver<WorkerRequest>,
     connection: Arc<Connection>,
+    config: Configuration,
 ) {
     if let Err(e) = do_run(
         transformer_backend,
-        memory_backend_tx,
-        last_worker_request,
+        memory_tx,
+        transformer_rx,
         connection,
+        config,
     ) {
         error!("error in transformer worker: {e}")
     }
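The new worker loop rate limits by counting requests inside a rolling five second window and skipping work once the averaged rate exceeds `max_requests_per_second`. As a standalone illustration (an assumed simplification, not the crate's actual code; the `RateLimiter` type and `allow` method are made up for the example), the same check looks roughly like this:

```rust
// Standalone sketch of the coarse rate limiting used in the worker loop:
// reset a counter every 5 seconds and skip work once the averaged rate
// exceeds max_requests_per_second. Names here are illustrative only.
use std::time::SystemTime;

struct RateLimiter {
    max_requests_per_second: f32,
    window_start: SystemTime,
    requests_in_window: f32,
}

impl RateLimiter {
    fn new(max_requests_per_second: f32) -> Self {
        Self {
            max_requests_per_second,
            window_start: SystemTime::now(),
            requests_in_window: 0.,
        }
    }

    /// Returns true if the request should be handled, false if it should be skipped.
    fn allow(&mut self) -> bool {
        // Start a fresh window every 5 seconds, mirroring the loop above.
        if self
            .window_start
            .elapsed()
            .map(|e| e.as_secs() > 5)
            .unwrap_or(true)
        {
            self.window_start = SystemTime::now();
            self.requests_in_window = 0.;
        }
        // Average over the window; if we are already past the cap, skip this request.
        if self.requests_in_window / 5. > self.max_requests_per_second {
            return false;
        }
        self.requests_in_window += 1.;
        true
    }
}

fn main() {
    // 0.25 is the llama.cpp default added in this commit (roughly one request per 4 seconds).
    let mut limiter = RateLimiter::new(0.25);
    for i in 0..5 {
        println!("request {i}: allowed = {}", limiter.allow());
    }
}
```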