mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2025-12-18 23:14:28 +01:00
Cleaned up llamacpp stuff
This commit is contained in:
@@ -34,6 +34,7 @@ async-trait = "0.1.78"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cublas = ["llama-cpp-2/cublas"]
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2.0.14"
|
||||
|
||||
@@ -218,19 +218,18 @@ pub struct ValidConfiguration {
|
||||
#[derive(Clone, Debug, Deserialize, Default)]
|
||||
pub struct ValidClientParams {
|
||||
#[serde(alias = "rootURI")]
|
||||
root_uri: Option<String>,
|
||||
workspace_folders: Option<Vec<String>>,
|
||||
_root_uri: Option<String>,
|
||||
_workspace_folders: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Configuration {
|
||||
pub config: ValidConfiguration,
|
||||
client_params: ValidClientParams,
|
||||
_client_params: ValidClientParams,
|
||||
}
|
||||
|
||||
impl Configuration {
|
||||
pub fn new(mut args: Value) -> Result<Self> {
|
||||
eprintln!("\n\n{}\n\n", args.to_string());
|
||||
let configuration_args = args
|
||||
.as_object_mut()
|
||||
.context("Server configuration must be a JSON object")?
|
||||
@@ -242,7 +241,7 @@ impl Configuration {
|
||||
let client_params: ValidClientParams = serde_json::from_value(args)?;
|
||||
Ok(Self {
|
||||
config: valid_args,
|
||||
client_params,
|
||||
_client_params: client_params,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -74,6 +74,7 @@ fn main() -> Result<()> {
|
||||
fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
|
||||
// Build our configuration
|
||||
let configuration = Configuration::new(args)?;
|
||||
eprintln!("GOT THE CONFIG: {:?}", configuration);
|
||||
|
||||
// Wrap the connection for sharing between threads
|
||||
let connection = Arc::new(connection);
|
||||
|
||||
@@ -3,7 +3,7 @@ use indexmap::IndexSet;
|
||||
use lsp_types::TextDocumentPositionParams;
|
||||
use parking_lot::Mutex;
|
||||
use ropey::Rope;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use std::collections::HashMap;
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::{
|
||||
@@ -14,29 +14,16 @@ use crate::{
|
||||
use super::{MemoryBackend, Prompt, PromptForType};
|
||||
|
||||
pub struct FileStore {
|
||||
crawl: bool,
|
||||
_crawl: bool,
|
||||
configuration: Configuration,
|
||||
file_map: Mutex<HashMap<String, Rope>>,
|
||||
accessed_files: Mutex<IndexSet<String>>,
|
||||
}
|
||||
|
||||
// TODO: Put some thought into the crawling here. Do we want to have a crawl option where it tries to crawl through all relevant
|
||||
// files and then when asked for context it loads them in by the most recently accessed? That seems kind of silly honestly, but I could see
|
||||
// how users who want to use models with massive context lengths would just want their entire project as context for generation tasks
|
||||
// I'm not sure yet, this is something I need to think through more
|
||||
|
||||
// Ok here are some more ideas
|
||||
// We take a crawl arg which is a bool of true or false for file_store
|
||||
// If true we crawl until we get to the max_context_length and then we stop crawling
|
||||
// We keep track of the last opened / changed files, and prioritize those when building the context for our llms
|
||||
|
||||
// For memory backends like PostgresML, they will need to take some kind of max_context_length to crawl or something.
|
||||
// In other words, there needs to be some specification for how much they should be crawling because the limiting happens in the vector_recall
|
||||
impl FileStore {
|
||||
pub fn new(file_store_config: configuration::FileStore, configuration: Configuration) -> Self {
|
||||
// TODO: maybe crawl
|
||||
Self {
|
||||
crawl: file_store_config.crawl,
|
||||
_crawl: file_store_config.crawl,
|
||||
configuration,
|
||||
file_map: Mutex::new(HashMap::new()),
|
||||
accessed_files: Mutex::new(IndexSet::new()),
|
||||
@@ -45,7 +32,7 @@ impl FileStore {
|
||||
|
||||
pub fn new_without_crawl(configuration: Configuration) -> Self {
|
||||
Self {
|
||||
crawl: false,
|
||||
_crawl: false,
|
||||
configuration,
|
||||
file_map: Mutex::new(HashMap::new()),
|
||||
accessed_files: Mutex::new(IndexSet::new()),
|
||||
@@ -135,8 +122,7 @@ impl FileStore {
|
||||
.unwrap_or(false),
|
||||
};
|
||||
|
||||
// We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
|
||||
// and the user has not enabled chat
|
||||
// We only want to do FIM if the user has enabled it and the user has not enabled chat
|
||||
Ok(match (is_chat_enabled, self.configuration.get_fim()) {
|
||||
r @ (true, _) | r @ (false, Some(_))
|
||||
if is_chat_enabled || rope.len_chars() != cursor_index =>
|
||||
|
||||
@@ -7,7 +7,7 @@ use anyhow::Context;
|
||||
use lsp_types::TextDocumentPositionParams;
|
||||
use pgml::{Collection, Pipeline};
|
||||
use serde_json::json;
|
||||
use tokio::{runtime::Runtime, time};
|
||||
use tokio::time;
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::{
|
||||
@@ -22,7 +22,6 @@ pub struct PostgresML {
|
||||
file_store: FileStore,
|
||||
collection: Collection,
|
||||
pipeline: Pipeline,
|
||||
runtime: Runtime,
|
||||
debounce_tx: Sender<String>,
|
||||
added_pipeline: bool,
|
||||
}
|
||||
@@ -62,12 +61,11 @@ impl PostgresML {
|
||||
.into(),
|
||||
),
|
||||
)?;
|
||||
// Create our own runtime
|
||||
// Setup up a debouncer for changed text documents
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(2)
|
||||
.enable_all()
|
||||
.build()?;
|
||||
// Setup up a debouncer for changed text documents
|
||||
let mut task_collection = collection.clone();
|
||||
let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
|
||||
runtime.spawn(async move {
|
||||
@@ -106,15 +104,11 @@ impl PostgresML {
|
||||
}
|
||||
}
|
||||
});
|
||||
// TODO: Maybe
|
||||
// Need to crawl the root path and or workspace folders
|
||||
// Or set some kind of did crawl for it
|
||||
Ok(Self {
|
||||
configuration,
|
||||
file_store,
|
||||
collection,
|
||||
pipeline,
|
||||
runtime,
|
||||
debounce_tx,
|
||||
added_pipeline: false,
|
||||
})
|
||||
|
||||
@@ -41,17 +41,13 @@ impl Anthropic {
|
||||
messages: Vec<ChatMessage>,
|
||||
max_tokens: usize,
|
||||
) -> anyhow::Result<String> {
|
||||
eprintln!(
|
||||
"SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
|
||||
messages
|
||||
);
|
||||
let client = reqwest::Client::new();
|
||||
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
|
||||
std::env::var(env_var_name)?
|
||||
} else if let Some(token) = &self.configuration.auth_token {
|
||||
token.to_string()
|
||||
} else {
|
||||
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
|
||||
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `transformer->anthropic` to use an Anthropic");
|
||||
};
|
||||
let res: AnthropicChatResponse = client
|
||||
.post(
|
||||
@@ -111,7 +107,7 @@ impl TransformerBackend for Anthropic {
|
||||
let insert_text = match &self.configuration.chat.completion {
|
||||
Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
|
||||
None => {
|
||||
anyhow::bail!("Please provide `anthropic->chat->completion` messages")
|
||||
anyhow::bail!("Please set `transformer->anthropic->chat->completion` messages")
|
||||
}
|
||||
};
|
||||
Ok(DoCompletionResponse { insert_text })
|
||||
@@ -124,7 +120,7 @@ impl TransformerBackend for Anthropic {
|
||||
let generated_text = match &self.configuration.chat.generation {
|
||||
Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
|
||||
None => {
|
||||
anyhow::bail!("Please provide `anthropic->chat->generation` messages")
|
||||
anyhow::bail!("Please set `transformer->anthropic->chat->generation` messages")
|
||||
}
|
||||
};
|
||||
Ok(DoGenerateResponse { generated_text })
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::Context;
|
||||
use hf_hub::api::sync::ApiBuilder;
|
||||
use tracing::{debug, instrument};
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::{
|
||||
configuration::{self},
|
||||
@@ -17,12 +17,12 @@ use model::Model;
|
||||
|
||||
use super::TransformerBackend;
|
||||
|
||||
pub struct LlamaCPP {
|
||||
pub struct LLaMACPP {
|
||||
model: Model,
|
||||
configuration: configuration::LLaMACPP,
|
||||
}
|
||||
|
||||
impl LlamaCPP {
|
||||
impl LLaMACPP {
|
||||
#[instrument]
|
||||
pub fn new(configuration: configuration::LLaMACPP) -> anyhow::Result<Self> {
|
||||
let api = ApiBuilder::new().with_progress(true).build()?;
|
||||
@@ -30,9 +30,9 @@ impl LlamaCPP {
|
||||
.model
|
||||
.name
|
||||
.as_ref()
|
||||
.context("Model `name` is required when using GGUF models")?;
|
||||
.context("Please set `transformer->llamacpp->name` to use LLaMA.cpp")?;
|
||||
let repo = api.model(configuration.model.repository.to_owned());
|
||||
let model_path = repo.get(&name)?;
|
||||
let model_path = repo.get(name)?;
|
||||
let model = Model::new(model_path, &configuration.kwargs)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
@@ -50,7 +50,7 @@ impl LlamaCPP {
|
||||
if let Some(chat_template) = &c.chat_template {
|
||||
let bos_token = self.model.get_bos_token()?;
|
||||
let eos_token = self.model.get_eos_token()?;
|
||||
apply_chat_template(&chat_template, chat_messages, &bos_token, &eos_token)?
|
||||
apply_chat_template(chat_template, chat_messages, &bos_token, &eos_token)?
|
||||
} else {
|
||||
self.model.apply_chat_template(chat_messages, None)?
|
||||
}
|
||||
@@ -64,12 +64,10 @@ impl LlamaCPP {
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TransformerBackend for LlamaCPP {
|
||||
impl TransformerBackend for LLaMACPP {
|
||||
#[instrument(skip(self))]
|
||||
async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
|
||||
// let prompt = self.get_prompt_string(prompt)?;
|
||||
let prompt = &prompt.code;
|
||||
debug!("Prompt string for LLM: {}", prompt);
|
||||
let prompt = self.get_prompt_string(prompt)?;
|
||||
let max_new_tokens = self.configuration.max_tokens.completion;
|
||||
self.model
|
||||
.complete(&prompt, max_new_tokens)
|
||||
@@ -78,9 +76,7 @@ impl TransformerBackend for LlamaCPP {
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
|
||||
// let prompt = self.get_prompt_string(prompt)?;
|
||||
// debug!("Prompt string for LLM: {}", prompt);
|
||||
let prompt = &prompt.code;
|
||||
let prompt = self.get_prompt_string(prompt)?;
|
||||
let max_new_tokens = self.configuration.max_tokens.completion;
|
||||
self.model
|
||||
.complete(&prompt, max_new_tokens)
|
||||
@@ -96,60 +92,55 @@ impl TransformerBackend for LlamaCPP {
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
// use serde_json::json;
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// #[test]
|
||||
// fn test_gguf() {
|
||||
// let args = json!({
|
||||
// "initializationOptions": {
|
||||
// "memory": {
|
||||
// "file_store": {}
|
||||
// },
|
||||
// "model_gguf": {
|
||||
// "repository": "stabilityai/stable-code-3b",
|
||||
// "name": "stable-code-3b-Q5_K_M.gguf",
|
||||
// "max_new_tokens": {
|
||||
// "completion": 32,
|
||||
// "generation": 256,
|
||||
// },
|
||||
// // "fim": {
|
||||
// // "start": "",
|
||||
// // "middle": "",
|
||||
// // "end": ""
|
||||
// // },
|
||||
// "chat": {
|
||||
// "completion": [
|
||||
// {
|
||||
// "role": "system",
|
||||
// "message": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief.\n\n{context}",
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "message": "Complete the following code: \n\n{code}"
|
||||
// }
|
||||
// ],
|
||||
// "generation": [
|
||||
// {
|
||||
// "role": "system",
|
||||
// "message": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "message": "Complete the following code: \n\n{code}"
|
||||
// }
|
||||
// ]
|
||||
// },
|
||||
// "n_ctx": 2048,
|
||||
// "n_gpu_layers": 1000,
|
||||
// }
|
||||
// },
|
||||
// });
|
||||
// let configuration = Configuration::new(args).unwrap();
|
||||
// let _model = LlamaCPP::new(configuration).unwrap();
|
||||
// // let output = model.do_completion("def fibon").unwrap();
|
||||
// // println!("{}", output.insert_text);
|
||||
// }
|
||||
// }
|
||||
#[tokio::test]
|
||||
async fn llama_cpp_do_completion() -> anyhow::Result<()> {
|
||||
let configuration: configuration::LLaMACPP = serde_json::from_value(json!({
|
||||
"repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
|
||||
"name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
|
||||
"max_new_tokens": {
|
||||
"completion": 32,
|
||||
"generation": 256,
|
||||
},
|
||||
"fim": {
|
||||
"start": "<fim_prefix>",
|
||||
"middle": "<fim_suffix>",
|
||||
"end": "<fim_middle>"
|
||||
},
|
||||
"chat": {
|
||||
// "completion": [
|
||||
// {
|
||||
// "role": "system",
|
||||
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "Complete the following code: \n\n{code}"
|
||||
// }
|
||||
// ],
|
||||
// "generation": [
|
||||
// {
|
||||
// "role": "system",
|
||||
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "Complete the following code: \n\n{code}"
|
||||
// }
|
||||
// ],
|
||||
"chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
|
||||
},
|
||||
"n_ctx": 2048,
|
||||
"n_gpu_layers": 35,
|
||||
}))?;
|
||||
let llama_cpp = LLaMACPP::new(configuration).unwrap();
|
||||
let prompt = Prompt::default_with_cursor();
|
||||
let response = llama_cpp.do_completion(&prompt).await?;
|
||||
assert!(!response.insert_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,20 +31,23 @@ impl Model {
|
||||
.unwrap_or_else(|| Ok(1000))?;
|
||||
|
||||
// Initialize the model_params
|
||||
let model_params = {
|
||||
#[cfg(feature = "cublas")]
|
||||
if !params.disable_gpu {
|
||||
LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers)
|
||||
} else {
|
||||
LlamaModelParams::default()
|
||||
}
|
||||
#[cfg(not(feature = "cublas"))]
|
||||
LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers)
|
||||
};
|
||||
let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);
|
||||
|
||||
// Load the model
|
||||
eprintln!();
|
||||
eprintln!();
|
||||
eprintln!();
|
||||
eprintln!();
|
||||
eprintln!();
|
||||
eprintln!();
|
||||
debug!("Loading model at path: {:?}", model_path);
|
||||
let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
|
||||
eprintln!();
|
||||
eprintln!("LOADED THE MODEL");
|
||||
eprintln!();
|
||||
eprintln!();
|
||||
eprintln!();
|
||||
eprintln!();
|
||||
|
||||
// Get n_ctx if set in kwargs
|
||||
// As a default we set it to 2048
|
||||
@@ -65,7 +68,7 @@ impl Model {
|
||||
pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
|
||||
// initialize the context
|
||||
let ctx_params = LlamaContextParams::default()
|
||||
.with_n_ctx(Some(self.n_ctx.clone()))
|
||||
.with_n_ctx(Some(self.n_ctx))
|
||||
.with_n_batch(self.n_ctx.get());
|
||||
|
||||
let mut ctx = self
|
||||
|
||||
@@ -26,7 +26,7 @@ impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send + Sync> {
|
||||
fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
|
||||
match configuration.config.transformer {
|
||||
ValidTransformerBackend::LLaMACPP(model_gguf) => {
|
||||
Ok(Box::new(llama_cpp::LlamaCPP::new(model_gguf)?))
|
||||
Ok(Box::new(llama_cpp::LLaMACPP::new(model_gguf)?))
|
||||
}
|
||||
ValidTransformerBackend::OpenAI(openai_config) => {
|
||||
Ok(Box::new(openai::OpenAI::new(openai_config)))
|
||||
|
||||
Reference in New Issue
Block a user