Mirror of https://github.com/SilasMarvin/lsp-ai.git (synced 2025-12-18 15:04:29 +01:00)

Bump llamacpp-2 version and general configuration cleanups

Cargo.lock (generated, 10 changed lines)

@@ -1409,8 +1409,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
 
 [[package]]
 name = "llama-cpp-2"
-version = "0.1.34"
-source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
+version = "0.1.47"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f391a790923a78bbe6100824124492c7df3d17b26340424eb813b88a521707a3"
 dependencies = [
  "llama-cpp-sys-2",
  "thiserror",
@@ -1419,8 +1420,9 @@ dependencies = [
 
 [[package]]
 name = "llama-cpp-sys-2"
-version = "0.1.34"
-source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
+version = "0.1.47"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1f26aac755fc36d5cc19f0853c4d359db8c4e4e5944705a27c7cce5e2cb9c36"
 dependencies = [
  "bindgen",
  "cc",
Cargo.toml

@@ -19,8 +19,7 @@ tokenizers = "0.14.1"
 parking_lot = "0.12.1"
 once_cell = "1.19.0"
 directories = "5.0.1"
-# llama-cpp-2 = "0.1.31"
-llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" }
+llama-cpp-2 = "0.1.47"
 minijinja = { version = "1.0.12", features = ["loader"] }
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 tracing = "0.1.40"
@@ -97,6 +97,10 @@ pub struct Model {
     pub name: Option<String>,
 }
 
+const fn llamacpp_max_requests_per_second_default() -> f32 {
+    0.25
+}
+
 #[derive(Clone, Debug, Deserialize)]
 pub struct LLaMACPP {
     // The model to use
@@ -109,6 +113,9 @@ pub struct LLaMACPP {
     pub max_tokens: MaxTokens,
     // Chat args
     pub chat: Option<Chat>,
+    // The maximum requests per second
+    #[serde(default = "llamacpp_max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
     // Kwargs passed to LlamaCPP
     #[serde(flatten)]
     pub kwargs: Kwargs,
@@ -128,28 +135,33 @@ impl Default for LLaMACPP {
             }),
             max_tokens: MaxTokens::default(),
             chat: None,
+            max_requests_per_second: f32::MAX,
             kwargs: Kwargs::default(),
         }
     }
 }
 
+const fn api_max_requests_per_second_default() -> f32 {
+    0.5
+}
+
 const fn openai_top_p_default() -> f32 {
     0.95
 }
 
-const fn openai_presence_penalty() -> f32 {
+const fn openai_presence_penalty_default() -> f32 {
     0.
 }
 
-const fn openai_frequency_penalty() -> f32 {
+const fn openai_frequency_penalty_default() -> f32 {
     0.
 }
 
-const fn openai_temperature() -> f32 {
+const fn openai_temperature_default() -> f32 {
     0.1
 }
 
-const fn openai_max_context() -> usize {
+const fn openai_max_context_default() -> usize {
     DEFAULT_OPENAI_MAX_CONTEXT
 }
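
Note: the hunks above come from the server's configuration module. The renamed *_default functions use serde's field-default hook: #[serde(default = "path")] calls the named function whenever that field is missing from the input, and the *_default suffix just makes that role explicit. A minimal, self-contained sketch of the pattern (illustrative only; the struct, field names, and values below are made up, and it assumes serde with the derive feature plus serde_json):

    use serde::Deserialize;

    // Stands in for the llamacpp/api default functions in the diff above.
    const fn max_requests_per_second_default() -> f32 {
        0.5
    }

    #[derive(Debug, Deserialize)]
    struct Backend {
        model: String,
        // Used only when "max_requests_per_second" is absent from the input.
        #[serde(default = "max_requests_per_second_default")]
        max_requests_per_second: f32,
    }

    fn main() {
        let backend: Backend = serde_json::from_str(r#"{ "model": "some-model" }"#).unwrap();
        assert_eq!(backend.max_requests_per_second, 0.5);
        println!("{backend:?}");
    }
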
@@ -162,6 +174,9 @@ pub struct OpenAI {
     pub completions_endpoint: Option<String>,
     // The chat endpoint
     pub chat_endpoint: Option<String>,
+    // The maximum requests per second
+    #[serde(default = "api_max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
     // The model name
     pub model: String,
     // Fill in the middle support
@@ -174,13 +189,13 @@ pub struct OpenAI {
     // Other available args
     #[serde(default = "openai_top_p_default")]
     pub top_p: f32,
-    #[serde(default = "openai_presence_penalty")]
+    #[serde(default = "openai_presence_penalty_default")]
     pub presence_penalty: f32,
-    #[serde(default = "openai_frequency_penalty")]
+    #[serde(default = "openai_frequency_penalty_default")]
     pub frequency_penalty: f32,
-    #[serde(default = "openai_temperature")]
+    #[serde(default = "openai_temperature_default")]
     pub temperature: f32,
-    #[serde(default = "openai_max_context")]
+    #[serde(default = "openai_max_context_default")]
     max_context: usize,
 }
@@ -193,6 +208,9 @@ pub struct Anthropic {
     pub completions_endpoint: Option<String>,
     // The chat endpoint
     pub chat_endpoint: Option<String>,
+    // The maximum requests per second
+    #[serde(default = "api_max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
     // The model name
     pub model: String,
     // The maximum number of new tokens to generate
@@ -203,15 +221,17 @@ pub struct Anthropic {
     // System prompt
     #[serde(default = "openai_top_p_default")]
     pub top_p: f32,
-    #[serde(default = "openai_temperature")]
+    #[serde(default = "openai_temperature_default")]
     pub temperature: f32,
-    #[serde(default = "openai_max_context")]
+    #[serde(default = "openai_max_context_default")]
     max_context: usize,
 }
 
 #[derive(Clone, Debug, Deserialize, Default)]
 pub struct ValidConfiguration {
+    #[serde(default)]
     pub memory: ValidMemoryBackend,
+    #[serde(default)]
     pub transformer: ValidTransformerBackend,
 }
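
Note: the bare #[serde(default)] attributes on ValidConfiguration take no function path, so they fall back to the field type's Default implementation, which should let an entirely empty configuration object deserialize. A rough sketch of that behavior, again with made-up types rather than the project's:

    use serde::Deserialize;

    #[derive(Debug, Default, Deserialize)]
    struct Memory {
        file_store: bool,
    }

    #[derive(Debug, Default, Deserialize)]
    struct Config {
        // Falls back to Memory::default() when the "memory" key is missing.
        #[serde(default)]
        memory: Memory,
    }

    fn main() {
        let config: Config = serde_json::from_str("{}").unwrap();
        assert!(!config.memory.file_store);
    }
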
@@ -246,9 +266,17 @@ impl Configuration {
     }
 
     ///////////////////////////////////////
-    // Helpers for the Memory Backend /////
+    // Helpers for the backends ///////////
     ///////////////////////////////////////
 
+    pub fn get_transformer_max_requests_per_second(&self) -> f32 {
+        match &self.config.transformer {
+            ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp.max_requests_per_second,
+            ValidTransformerBackend::OpenAI(openai) => openai.max_requests_per_second,
+            ValidTransformerBackend::Anthropic(anthropic) => anthropic.max_requests_per_second,
+        }
+    }
+
     pub fn get_max_context_length(&self) -> usize {
         match &self.config.transformer {
             ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp
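
Note: get_transformer_max_requests_per_second is a plain dispatch over whichever backend is configured, each backend carrying its own limit (serde defaults of 0.25 for llama.cpp and 0.5 for the API backends above, while the LLaMACPP Default impl uses f32::MAX, effectively no limit). A stripped-down sketch of the same match shape, with hypothetical stand-in types:

    struct LlamaCpp {
        max_requests_per_second: f32,
    }
    struct OpenAi {
        max_requests_per_second: f32,
    }

    enum TransformerBackend {
        LlamaCpp(LlamaCpp),
        OpenAi(OpenAi),
    }

    impl TransformerBackend {
        // One limit per configured backend, mirroring the helper added above.
        fn max_requests_per_second(&self) -> f32 {
            match self {
                TransformerBackend::LlamaCpp(b) => b.max_requests_per_second,
                TransformerBackend::OpenAi(b) => b.max_requests_per_second,
            }
        }
    }

    fn main() {
        let backend = TransformerBackend::OpenAi(OpenAi {
            max_requests_per_second: 0.5,
        });
        assert_eq!(backend.max_requests_per_second(), 0.5);
    }
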

src/main.rs (32 changed lines)

@@ -5,7 +5,6 @@ use lsp_types::{
     request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
     RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
 };
-use parking_lot::Mutex;
 use std::{
     sync::{mpsc, Arc},
     thread,
@@ -75,33 +74,35 @@ fn main() -> Result<()> {
 
 fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     // Build our configuration
-    let configuration = Configuration::new(args)?;
+    let config = Configuration::new(args)?;
 
     // Wrap the connection for sharing between threads
     let connection = Arc::new(connection);
 
     // Our channel we use to communicate with our transformer worker
-    let last_worker_request = Arc::new(Mutex::new(None));
+    // let last_worker_request = Arc::new(Mutex::new(None));
+    let (transformer_tx, transformer_rx) = mpsc::channel();
 
     // The channel we use to communicate with our memory worker
     let (memory_tx, memory_rx) = mpsc::channel();
 
     // Setup the transformer worker
-    let memory_backend: Box<dyn MemoryBackend + Send + Sync> = configuration.clone().try_into()?;
+    let memory_backend: Box<dyn MemoryBackend + Send + Sync> = config.clone().try_into()?;
     thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
 
     // Setup our transformer worker
     let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
-        configuration.clone().try_into()?;
-    let thread_last_worker_request = last_worker_request.clone();
+        config.clone().try_into()?;
     let thread_connection = connection.clone();
     let thread_memory_tx = memory_tx.clone();
+    let thread_config = config.clone();
     thread::spawn(move || {
         transformer_worker::run(
             transformer_backend,
             thread_memory_tx,
-            thread_last_worker_request,
+            transformer_rx,
             thread_connection,
+            thread_config,
         )
     });
@@ -111,33 +112,28 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
         if connection.handle_shutdown(&req)? {
             return Ok(());
         }
         // Right now each if / else basically does the same thing,
         // but this may change soon so it is worth making it a little
         // more verbose than it needs to be now
         if request_is::<Completion>(&req) {
             match cast::<Completion>(req) {
                 Ok((id, params)) => {
-                    let mut lcr = last_worker_request.lock();
                     let completion_request = CompletionRequest::new(id, params);
-                    *lcr = Some(WorkerRequest::Completion(completion_request));
+                    transformer_tx.send(WorkerRequest::Completion(completion_request))?;
                 }
                 Err(err) => error!("{err:?}"),
             }
         } else if request_is::<Generate>(&req) {
             match cast::<Generate>(req) {
                 Ok((id, params)) => {
-                    let mut lcr = last_worker_request.lock();
-                    let completion_request = GenerateRequest::new(id, params);
-                    *lcr = Some(WorkerRequest::Generate(completion_request));
+                    let generate_request = GenerateRequest::new(id, params);
+                    transformer_tx.send(WorkerRequest::Generate(generate_request))?;
                 }
                 Err(err) => error!("{err:?}"),
             }
         } else if request_is::<GenerateStream>(&req) {
             match cast::<GenerateStream>(req) {
                 Ok((id, params)) => {
-                    let mut lcr = last_worker_request.lock();
-                    let completion_request = GenerateStreamRequest::new(id, params);
-                    *lcr = Some(WorkerRequest::GenerateStream(completion_request));
+                    let generate_stream_request = GenerateStreamRequest::new(id, params);
+                    transformer_tx
+                        .send(WorkerRequest::GenerateStream(generate_stream_request))?;
                 }
                 Err(err) => error!("{err:?}"),
             }
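
Note: the handler changes above go together with the main_loop changes: instead of writing each incoming request into a shared Mutex<Option<WorkerRequest>> slot that the worker polled (where a new request overwrote any not-yet-handled one), every request is now sent down a std::sync::mpsc channel and the worker blocks on recv. A minimal sketch of that handoff, with a hypothetical Request enum rather than the project's WorkerRequest:

    use std::sync::mpsc;
    use std::thread;

    #[derive(Debug)]
    enum Request {
        Completion(String),
        Generate(String),
    }

    fn main() {
        let (tx, rx) = mpsc::channel::<Request>();

        // Worker side: block on recv instead of polling a shared slot.
        let worker = thread::spawn(move || {
            while let Ok(request) = rx.recv() {
                println!("handling {request:?}");
            }
        });

        // LSP-loop side: every request is queued; none are overwritten.
        tx.send(Request::Completion("complete at cursor".into())).unwrap();
        tx.send(Request::Generate("generate a test".into())).unwrap();
        drop(tx); // closing the channel lets the worker loop end

        worker.join().unwrap();
    }
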
@@ -3,11 +3,12 @@ use lsp_types::{
     CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
     Position, Range, TextEdit,
 };
-use parking_lot::Mutex;
-use std::{sync::Arc, thread};
+use std::sync::Arc;
+use std::time::SystemTime;
 use tokio::sync::oneshot;
-use tracing::error;
+use tracing::{debug, error, instrument};
 
+use crate::configuration::Configuration;
 use crate::custom_requests::generate::{GenerateParams, GenerateResult};
 use crate::custom_requests::generate_stream::GenerateStreamParams;
 use crate::memory_backends::PromptForType;
@@ -53,7 +54,7 @@ impl GenerateStreamRequest {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub enum WorkerRequest {
     Completion(CompletionRequest),
     Generate(GenerateRequest),
@@ -72,6 +73,7 @@ pub struct DoGenerateStreamResponse {
     pub generated_text: String,
 }
 
+#[instrument(skip(transformer_backend, memory_backend_tx, connection))]
 async fn do_task(
     transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
     memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
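
Note: these hunks are from the transformer worker module. #[instrument] comes from the tracing crate (already a dependency in the manifest above): it wraps the function in a tracing span and records the remaining arguments as span fields, while skip(...) leaves out the ones that are noisy or not Debug-printable, here the backend, channel, and connection handles. A small sketch of the mechanism, unrelated to this repository's types:

    use tracing::{info, instrument};

    #[instrument(skip(payload))]
    fn handle(request_id: u64, payload: &str) {
        // request_id is recorded on the span; payload is skipped.
        info!(len = payload.len(), "handling request");
    }

    fn main() {
        // tracing-subscriber is also already in the dependency list above.
        tracing_subscriber::fmt().init();
        handle(42, "some payload");
    }
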
@@ -100,21 +102,22 @@ async fn do_task(
             }
         }
         WorkerRequest::GenerateStream(_) => {
-            panic!("Streaming is not supported yet")
+            panic!("Streaming is not yet supported")
         }
     };
     connection
         .sender
         .send(Message::Response(response))
-        .expect("Error sending message");
+        .expect("Error sending response");
     Ok(())
 }
 
 fn do_run(
     transformer_backend: Box<dyn TransformerBackend + Send + Sync>,
     memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
-    last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
+    transformer_rx: std::sync::mpsc::Receiver<WorkerRequest>,
     connection: Arc<Connection>,
+    config: Configuration,
 ) -> anyhow::Result<()> {
     let transformer_backend = Arc::new(transformer_backend);
     let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -122,43 +125,55 @@ fn do_run(
         .enable_all()
         .build()?;
 
+    // This logic is not perfect, but works well enough for now
+    let max_requests_per_second = config.get_transformer_max_requests_per_second();
+    let mut first_request = SystemTime::now();
+    let mut requests_in_last_5_seconds = 0.;
+
     loop {
-        let option_worker_request: Option<WorkerRequest> = {
-            let mut completion_request = last_worker_request.lock();
-            std::mem::take(&mut *completion_request)
-        };
-        if let Some(request) = option_worker_request {
-            let thread_transformer_backend = transformer_backend.clone();
-            let thread_memory_backend_tx = memory_backend_tx.clone();
-            let thread_connection = connection.clone();
-            runtime.spawn(async move {
-                if let Err(e) = do_task(
-                    thread_transformer_backend,
-                    thread_memory_backend_tx,
-                    request,
-                    thread_connection,
-                )
-                .await
-                {
-                    error!("error in transformer worker task: {e}")
-                }
-            });
+        let request = transformer_rx.recv()?;
+
+        if first_request.elapsed()?.as_secs() > 5 {
+            first_request = SystemTime::now();
+            requests_in_last_5_seconds = 0.;
         }
-        thread::sleep(std::time::Duration::from_millis(5));
+        if requests_in_last_5_seconds / 5. > max_requests_per_second {
+            debug!("rate limiting transform request");
+            continue;
+        }
+        requests_in_last_5_seconds += 1.;
+
+        let thread_transformer_backend = transformer_backend.clone();
+        let thread_memory_backend_tx = memory_backend_tx.clone();
+        let thread_connection = connection.clone();
+        runtime.spawn(async move {
+            if let Err(e) = do_task(
+                thread_transformer_backend,
+                thread_memory_backend_tx,
+                request,
+                thread_connection,
+            )
+            .await
+            {
+                error!("transformer worker task: {e}")
+            }
+        });
     }
 }
 
 pub fn run(
     transformer_backend: Box<dyn TransformerBackend + Send + Sync>,
-    memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
-    last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
+    memory_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
+    transformer_rx: std::sync::mpsc::Receiver<WorkerRequest>,
     connection: Arc<Connection>,
+    config: Configuration,
 ) {
     if let Err(e) = do_run(
         transformer_backend,
-        memory_backend_tx,
-        last_worker_request,
+        memory_tx,
+        transformer_rx,
         connection,
+        config,
     ) {
         error!("error in transformer worker: {e}")
     }
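
Note: the rewritten do_run loop above is where the new max_requests_per_second setting is enforced. Requests are counted over a rolling window of roughly five seconds, and when count / 5 exceeds the configured limit the request is dropped with a debug log; as the in-code comment concedes, the logic is "not perfect, but works well enough for now". A standalone sketch of just that throttling behavior, with hypothetical inputs in place of the real request channel:

    use std::time::SystemTime;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let max_requests_per_second = 0.5_f32; // e.g. the API-backend default above
        let mut window_start = SystemTime::now();
        let mut requests_in_window = 0.0_f32;

        for i in 0..20 {
            // Reset the counter every ~5 seconds.
            if window_start.elapsed()?.as_secs() > 5 {
                window_start = SystemTime::now();
                requests_in_window = 0.;
            }
            // Drop the request if the average rate over the window is too high.
            if requests_in_window / 5. > max_requests_per_second {
                println!("rate limiting request {i}");
                continue;
            }
            requests_in_window += 1.;
            println!("handling request {i}");
        }
        Ok(())
    }
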