Bump llamacpp-2 version and general configuration cleanups

Author: SilasMarvin
Date: 2024-04-06 14:43:17 -07:00
Parent: f921203aa6
Commit: 25b312fa2a
5 changed files with 107 additions and 67 deletions

Cargo.lock (generated, 10 changes)

@@ -1409,8 +1409,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "llama-cpp-2"
version = "0.1.34"
source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
version = "0.1.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f391a790923a78bbe6100824124492c7df3d17b26340424eb813b88a521707a3"
dependencies = [
"llama-cpp-sys-2",
"thiserror",
@@ -1419,8 +1420,9 @@ dependencies = [
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.34"
source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
version = "0.1.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1f26aac755fc36d5cc19f0853c4d359db8c4e4e5944705a27c7cce5e2cb9c36"
dependencies = [
"bindgen",
"cc",

Cargo.toml

@@ -19,8 +19,7 @@ tokenizers = "0.14.1"
parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
-# llama-cpp-2 = "0.1.31"
-llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" }
+llama-cpp-2 = "0.1.47"
minijinja = { version = "1.0.12", features = ["loader"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"

configuration.rs

@@ -97,6 +97,10 @@ pub struct Model {
pub name: Option<String>,
}
+const fn llamacpp_max_requests_per_second_default() -> f32 {
+0.25
+}
#[derive(Clone, Debug, Deserialize)]
pub struct LLaMACPP {
// The model to use
@@ -109,6 +113,9 @@ pub struct LLaMACPP {
pub max_tokens: MaxTokens,
// Chat args
pub chat: Option<Chat>,
+// The maximum requests per second
+#[serde(default = "llamacpp_max_requests_per_second_default")]
+pub max_requests_per_second: f32,
// Kwargs passed to LlamaCPP
#[serde(flatten)]
pub kwargs: Kwargs,
@@ -128,28 +135,33 @@ impl Default for LLaMACPP {
}),
max_tokens: MaxTokens::default(),
chat: None,
+max_requests_per_second: f32::MAX,
kwargs: Kwargs::default(),
}
}
}
+const fn api_max_requests_per_second_default() -> f32 {
+0.5
+}
const fn openai_top_p_default() -> f32 {
0.95
}
-const fn openai_presence_penalty() -> f32 {
+const fn openai_presence_penalty_default() -> f32 {
0.
}
-const fn openai_frequency_penalty() -> f32 {
+const fn openai_frequency_penalty_default() -> f32 {
0.
}
-const fn openai_temperature() -> f32 {
+const fn openai_temperature_default() -> f32 {
0.1
}
-const fn openai_max_context() -> usize {
+const fn openai_max_context_default() -> usize {
DEFAULT_OPENAI_MAX_CONTEXT
}
@@ -162,6 +174,9 @@ pub struct OpenAI {
pub completions_endpoint: Option<String>,
// The chat endpoint
pub chat_endpoint: Option<String>,
+// The maximum requests per second
+#[serde(default = "api_max_requests_per_second_default")]
+pub max_requests_per_second: f32,
// The model name
pub model: String,
// Fill in the middle support
@@ -174,13 +189,13 @@ pub struct OpenAI {
// Other available args
#[serde(default = "openai_top_p_default")]
pub top_p: f32,
#[serde(default = "openai_presence_penalty")]
#[serde(default = "openai_presence_penalty_default")]
pub presence_penalty: f32,
#[serde(default = "openai_frequency_penalty")]
#[serde(default = "openai_frequency_penalty_default")]
pub frequency_penalty: f32,
#[serde(default = "openai_temperature")]
#[serde(default = "openai_temperature_default")]
pub temperature: f32,
#[serde(default = "openai_max_context")]
#[serde(default = "openai_max_context_default")]
max_context: usize,
}
@@ -193,6 +208,9 @@ pub struct Anthropic {
pub completions_endpoint: Option<String>,
// The chat endpoint
pub chat_endpoint: Option<String>,
+// The maximum requests per second
+#[serde(default = "api_max_requests_per_second_default")]
+pub max_requests_per_second: f32,
// The model name
pub model: String,
// The maximum number of new tokens to generate
@@ -203,15 +221,17 @@ pub struct Anthropic {
// System prompt
#[serde(default = "openai_top_p_default")]
pub top_p: f32,
#[serde(default = "openai_temperature")]
#[serde(default = "openai_temperature_default")]
pub temperature: f32,
#[serde(default = "openai_max_context")]
#[serde(default = "openai_max_context_default")]
max_context: usize,
}
#[derive(Clone, Debug, Deserialize, Default)]
pub struct ValidConfiguration {
+#[serde(default)]
pub memory: ValidMemoryBackend,
+#[serde(default)]
pub transformer: ValidTransformerBackend,
}
@@ -246,9 +266,17 @@ impl Configuration {
}
///////////////////////////////////////
-// Helpers for the Memory Backend /////
+// Helpers for the backends ///////////
///////////////////////////////////////
+pub fn get_transformer_max_requests_per_second(&self) -> f32 {
+match &self.config.transformer {
+ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp.max_requests_per_second,
+ValidTransformerBackend::OpenAI(openai) => openai.max_requests_per_second,
+ValidTransformerBackend::Anthropic(anthropic) => anthropic.max_requests_per_second,
+}
+}
pub fn get_max_context_length(&self) -> usize {
match &self.config.transformer {
ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp
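
As context for the new serde defaults above: when a user's configuration omits max_requests_per_second, serde falls back to the named default function (0.25 for the llama.cpp backend, 0.5 for the API backends). The sketch below illustrates that behavior with a trimmed, hypothetical stand-in struct rather than the real LLaMACPP config; it assumes serde and serde_json are available.

use serde::Deserialize;

const fn llamacpp_max_requests_per_second_default() -> f32 {
    0.25
}

// Trimmed, hypothetical stand-in for the real LLaMACPP config struct.
#[derive(Debug, Deserialize)]
struct LlamaCppConfig {
    model: String,
    #[serde(default = "llamacpp_max_requests_per_second_default")]
    max_requests_per_second: f32,
}

fn main() -> Result<(), serde_json::Error> {
    // Field omitted: serde falls back to the default function (0.25).
    let cfg: LlamaCppConfig = serde_json::from_str(r#"{ "model": "model.gguf" }"#)?;
    println!("model = {}, rps = {}", cfg.model, cfg.max_requests_per_second);
    assert_eq!(cfg.max_requests_per_second, 0.25);

    // Field present: the user-supplied value wins.
    let cfg: LlamaCppConfig =
        serde_json::from_str(r#"{ "model": "model.gguf", "max_requests_per_second": 1.0 }"#)?;
    assert_eq!(cfg.max_requests_per_second, 1.0);
    Ok(())
}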

main.rs

@@ -5,7 +5,6 @@ use lsp_types::{
request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
};
-use parking_lot::Mutex;
use std::{
sync::{mpsc, Arc},
thread,
@@ -75,33 +74,35 @@ fn main() -> Result<()> {
fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
// Build our configuration
-let configuration = Configuration::new(args)?;
+let config = Configuration::new(args)?;
// Wrap the connection for sharing between threads
let connection = Arc::new(connection);
// Our channel we use to communicate with our transformer worker
-let last_worker_request = Arc::new(Mutex::new(None));
+// let last_worker_request = Arc::new(Mutex::new(None));
+let (transformer_tx, transformer_rx) = mpsc::channel();
// The channel we use to communicate with our memory worker
let (memory_tx, memory_rx) = mpsc::channel();
// Setup the transformer worker
-let memory_backend: Box<dyn MemoryBackend + Send + Sync> = configuration.clone().try_into()?;
+let memory_backend: Box<dyn MemoryBackend + Send + Sync> = config.clone().try_into()?;
thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
// Setup our transformer worker
let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
-configuration.clone().try_into()?;
-let thread_last_worker_request = last_worker_request.clone();
+config.clone().try_into()?;
let thread_connection = connection.clone();
let thread_memory_tx = memory_tx.clone();
+let thread_config = config.clone();
thread::spawn(move || {
transformer_worker::run(
transformer_backend,
thread_memory_tx,
-thread_last_worker_request,
+transformer_rx,
thread_connection,
+thread_config,
)
});
@@ -111,33 +112,28 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
if connection.handle_shutdown(&req)? {
return Ok(());
}
-// Right now each if / else basically does the same thing,
-// but this may change soon so it is worth making it a little
-// more verbose than it needs to be now
if request_is::<Completion>(&req) {
match cast::<Completion>(req) {
Ok((id, params)) => {
-let mut lcr = last_worker_request.lock();
let completion_request = CompletionRequest::new(id, params);
-*lcr = Some(WorkerRequest::Completion(completion_request));
+transformer_tx.send(WorkerRequest::Completion(completion_request))?;
}
Err(err) => error!("{err:?}"),
}
} else if request_is::<Generate>(&req) {
match cast::<Generate>(req) {
Ok((id, params)) => {
-let mut lcr = last_worker_request.lock();
-let completion_request = GenerateRequest::new(id, params);
-*lcr = Some(WorkerRequest::Generate(completion_request));
+let generate_request = GenerateRequest::new(id, params);
+transformer_tx.send(WorkerRequest::Generate(generate_request))?;
}
Err(err) => error!("{err:?}"),
}
} else if request_is::<GenerateStream>(&req) {
match cast::<GenerateStream>(req) {
Ok((id, params)) => {
-let mut lcr = last_worker_request.lock();
-let completion_request = GenerateStreamRequest::new(id, params);
-*lcr = Some(WorkerRequest::GenerateStream(completion_request));
+let generate_stream_request = GenerateStreamRequest::new(id, params);
+transformer_tx
+.send(WorkerRequest::GenerateStream(generate_stream_request))?;
}
Err(err) => error!("{err:?}"),
}
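
The main.rs changes above replace the shared Arc<Mutex<Option<WorkerRequest>>> slot, which the transformer worker previously polled, with a std::sync::mpsc channel the worker can block on. A minimal standalone sketch of that pattern, using a placeholder Request enum and println! handlers rather than the project's WorkerRequest and backends:

use std::sync::mpsc;
use std::thread;

// Placeholder for the project's WorkerRequest enum.
enum Request {
    Completion(String),
    Generate(String),
}

fn main() {
    let (tx, rx) = mpsc::channel::<Request>();

    // Worker thread: blocks on recv() instead of polling a shared Mutex slot,
    // so no sleep loop is needed and queued requests are handled in order.
    let worker = thread::spawn(move || {
        while let Ok(request) = rx.recv() {
            match request {
                Request::Completion(text) => println!("completion request: {text}"),
                Request::Generate(text) => println!("generate request: {text}"),
            }
        }
        // recv() returns Err once every sender is dropped, ending the loop.
    });

    // The LSP loop side: just send; no locking required.
    tx.send(Request::Completion("fn main() {".into())).unwrap();
    tx.send(Request::Generate("write a haiku".into())).unwrap();

    drop(tx); // close the channel so the worker exits
    worker.join().unwrap();
}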

transformer_worker.rs

@@ -3,11 +3,12 @@ use lsp_types::{
CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
Position, Range, TextEdit,
};
-use parking_lot::Mutex;
-use std::{sync::Arc, thread};
+use std::sync::Arc;
+use std::time::SystemTime;
use tokio::sync::oneshot;
-use tracing::error;
+use tracing::{debug, error, instrument};
+use crate::configuration::Configuration;
use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::GenerateStreamParams;
use crate::memory_backends::PromptForType;
@@ -53,7 +54,7 @@ impl GenerateStreamRequest {
}
}
-#[derive(Clone)]
+#[derive(Clone, Debug)]
pub enum WorkerRequest {
Completion(CompletionRequest),
Generate(GenerateRequest),
@@ -72,6 +73,7 @@ pub struct DoGenerateStreamResponse {
pub generated_text: String,
}
+#[instrument(skip(transformer_backend, memory_backend_tx, connection))]
async fn do_task(
transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
@@ -100,21 +102,22 @@ async fn do_task(
}
}
WorkerRequest::GenerateStream(_) => {
panic!("Streaming is not supported yet")
panic!("Streaming is not yet supported")
}
};
connection
.sender
.send(Message::Response(response))
.expect("Error sending message");
.expect("Error sending response");
Ok(())
}
fn do_run(
transformer_backend: Box<dyn TransformerBackend + Send + Sync>,
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
-last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
+transformer_rx: std::sync::mpsc::Receiver<WorkerRequest>,
connection: Arc<Connection>,
+config: Configuration,
) -> anyhow::Result<()> {
let transformer_backend = Arc::new(transformer_backend);
let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -122,43 +125,55 @@ fn do_run(
.enable_all()
.build()?;
+// This logic is not perfect, but works well enough for now
+let max_requests_per_second = config.get_transformer_max_requests_per_second();
+let mut first_request = SystemTime::now();
+let mut requests_in_last_5_seconds = 0.;
loop {
-let option_worker_request: Option<WorkerRequest> = {
-let mut completion_request = last_worker_request.lock();
-std::mem::take(&mut *completion_request)
-};
-if let Some(request) = option_worker_request {
-let thread_transformer_backend = transformer_backend.clone();
-let thread_memory_backend_tx = memory_backend_tx.clone();
-let thread_connection = connection.clone();
-runtime.spawn(async move {
-if let Err(e) = do_task(
-thread_transformer_backend,
-thread_memory_backend_tx,
-request,
-thread_connection,
-)
-.await
-{
-error!("error in transformer worker task: {e}")
-}
-});
+let request = transformer_rx.recv()?;
+if first_request.elapsed()?.as_secs() > 5 {
+first_request = SystemTime::now();
+requests_in_last_5_seconds = 0.;
}
-thread::sleep(std::time::Duration::from_millis(5));
+if requests_in_last_5_seconds / 5. > max_requests_per_second {
+debug!("rate limiting transform request");
+continue;
+}
+requests_in_last_5_seconds += 1.;
+let thread_transformer_backend = transformer_backend.clone();
+let thread_memory_backend_tx = memory_backend_tx.clone();
+let thread_connection = connection.clone();
+runtime.spawn(async move {
+if let Err(e) = do_task(
+thread_transformer_backend,
+thread_memory_backend_tx,
+request,
+thread_connection,
+)
+.await
+{
+error!("transformer worker task: {e}")
+}
+});
}
}
pub fn run(
transformer_backend: Box<dyn TransformerBackend + Send + Sync>,
-memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
-last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
+memory_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
+transformer_rx: std::sync::mpsc::Receiver<WorkerRequest>,
connection: Arc<Connection>,
+config: Configuration,
) {
if let Err(e) = do_run(
transformer_backend,
-memory_backend_tx,
-last_worker_request,
+memory_tx,
+transformer_rx,
connection,
+config,
) {
error!("error in transformer worker: {e}")
}
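
The loop in do_run now applies a coarse fixed-window rate limit: it counts dispatched requests, resets the counter once the window is more than five seconds old, and skips a request when the windowed average would exceed max_requests_per_second. A standalone sketch of that check, using hypothetical names and Instant in place of SystemTime:

use std::time::{Duration, Instant};

// Fixed-window limiter: allows roughly `max_per_second` requests on average,
// measured over five-second windows. Mirrors the shape of the check in do_run,
// but with Instant instead of SystemTime and placeholder naming.
struct RateLimiter {
    window_start: Instant,
    requests_in_window: f32,
    max_per_second: f32,
}

impl RateLimiter {
    fn new(max_per_second: f32) -> Self {
        Self {
            window_start: Instant::now(),
            requests_in_window: 0.,
            max_per_second,
        }
    }

    // Returns true if the request may proceed, false if it should be dropped.
    fn allow(&mut self) -> bool {
        if self.window_start.elapsed() > Duration::from_secs(5) {
            self.window_start = Instant::now();
            self.requests_in_window = 0.;
        }
        if self.requests_in_window / 5. > self.max_per_second {
            return false; // over budget for this window
        }
        self.requests_in_window += 1.;
        true
    }
}

fn main() {
    // 0.5 requests per second allows about three requests per 5-second window.
    let mut limiter = RateLimiter::new(0.5);
    for i in 0..5 {
        println!("request {i}: allowed = {}", limiter.allow());
    }
}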