Cleaned up llamacpp stuff

This commit is contained in:
SilasMarvin
2024-04-03 20:37:00 -07:00
parent c31572e76f
commit db4de877d3
9 changed files with 91 additions and 120 deletions

View File

@@ -34,6 +34,7 @@ async-trait = "0.1.78"
[features]
default = []
cublas = ["llama-cpp-2/cublas"]
[dev-dependencies]
assert_cmd = "2.0.14"

View File

@@ -218,19 +218,18 @@ pub struct ValidConfiguration {
#[derive(Clone, Debug, Deserialize, Default)]
pub struct ValidClientParams {
#[serde(alias = "rootURI")]
root_uri: Option<String>,
workspace_folders: Option<Vec<String>>,
_root_uri: Option<String>,
_workspace_folders: Option<Vec<String>>,
}
#[derive(Clone, Debug)]
pub struct Configuration {
pub config: ValidConfiguration,
client_params: ValidClientParams,
_client_params: ValidClientParams,
}
impl Configuration {
pub fn new(mut args: Value) -> Result<Self> {
eprintln!("\n\n{}\n\n", args.to_string());
let configuration_args = args
.as_object_mut()
.context("Server configuration must be a JSON object")?
@@ -242,7 +241,7 @@ impl Configuration {
let client_params: ValidClientParams = serde_json::from_value(args)?;
Ok(Self {
config: valid_args,
client_params,
_client_params: client_params,
})
}

View File

@@ -74,6 +74,7 @@ fn main() -> Result<()> {
fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
// Build our configuration
let configuration = Configuration::new(args)?;
eprintln!("GOT THE CONFIG: {:?}", configuration);
// Wrap the connection for sharing between threads
let connection = Arc::new(connection);

View File

@@ -3,7 +3,7 @@ use indexmap::IndexSet;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use ropey::Rope;
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use tracing::instrument;
use crate::{
@@ -14,29 +14,16 @@ use crate::{
use super::{MemoryBackend, Prompt, PromptForType};
pub struct FileStore {
crawl: bool,
_crawl: bool,
configuration: Configuration,
file_map: Mutex<HashMap<String, Rope>>,
accessed_files: Mutex<IndexSet<String>>,
}
// TODO: Put some thought into the crawling here. Do we want to have a crawl option where it tries to crawl through all relevant
// files and then when asked for context it loads them in by the most recently accessed? That seems kind of silly honestly, but I could see
// how users who want to use models with massive context lengths would just want their entire project as context for generation tasks
// I'm not sure yet, this is something I need to think through more
// Ok here are some more ideas
// We take a crawl arg which is a bool of true or false for file_store
// If true we crawl until we get to the max_context_length and then we stop crawling
// We keep track of the last opened / changed files, and prioritize those when building the context for our llms
// For memory backends like PostgresML, they will need to take some kind of max_context_length to crawl or something.
// In other words, there needs to be some specification for how much they should be crawling because the limiting happens in the vector_recall
impl FileStore {
pub fn new(file_store_config: configuration::FileStore, configuration: Configuration) -> Self {
// TODO: maybe crawl
Self {
crawl: file_store_config.crawl,
_crawl: file_store_config.crawl,
configuration,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
@@ -45,7 +32,7 @@ impl FileStore {
pub fn new_without_crawl(configuration: Configuration) -> Self {
Self {
crawl: false,
_crawl: false,
configuration,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
@@ -135,8 +122,7 @@ impl FileStore {
.unwrap_or(false),
};
// We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
// and the user has not enabled chat
// We only want to do FIM if the user has enabled it and the user has not enabled chat
Ok(match (is_chat_enabled, self.configuration.get_fim()) {
r @ (true, _) | r @ (false, Some(_))
if is_chat_enabled || rope.len_chars() != cursor_index =>

View File

@@ -7,7 +7,7 @@ use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use pgml::{Collection, Pipeline};
use serde_json::json;
use tokio::{runtime::Runtime, time};
use tokio::time;
use tracing::instrument;
use crate::{
@@ -22,7 +22,6 @@ pub struct PostgresML {
file_store: FileStore,
collection: Collection,
pipeline: Pipeline,
runtime: Runtime,
debounce_tx: Sender<String>,
added_pipeline: bool,
}
@@ -62,12 +61,11 @@ impl PostgresML {
.into(),
),
)?;
// Create our own runtime
// Setup up a debouncer for changed text documents
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()?;
// Setup up a debouncer for changed text documents
let mut task_collection = collection.clone();
let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
runtime.spawn(async move {
@@ -106,15 +104,11 @@ impl PostgresML {
}
}
});
// TODO: Maybe
// Need to crawl the root path and or workspace folders
// Or set some kind of did crawl for it
Ok(Self {
configuration,
file_store,
collection,
pipeline,
runtime,
debounce_tx,
added_pipeline: false,
})

View File

@@ -41,17 +41,13 @@ impl Anthropic {
messages: Vec<ChatMessage>,
max_tokens: usize,
) -> anyhow::Result<String> {
eprintln!(
"SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
messages
);
let client = reqwest::Client::new();
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
std::env::var(env_var_name)?
} else if let Some(token) = &self.configuration.auth_token {
token.to_string()
} else {
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `transformer->anthropic` to use an Anthropic");
};
let res: AnthropicChatResponse = client
.post(
@@ -111,7 +107,7 @@ impl TransformerBackend for Anthropic {
let insert_text = match &self.configuration.chat.completion {
Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
None => {
anyhow::bail!("Please provide `anthropic->chat->completion` messages")
anyhow::bail!("Please set `transformer->anthropic->chat->completion` messages")
}
};
Ok(DoCompletionResponse { insert_text })
@@ -124,7 +120,7 @@ impl TransformerBackend for Anthropic {
let generated_text = match &self.configuration.chat.generation {
Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
None => {
anyhow::bail!("Please provide `anthropic->chat->generation` messages")
anyhow::bail!("Please set `transformer->anthropic->chat->generation` messages")
}
};
Ok(DoGenerateResponse { generated_text })

View File

@@ -1,6 +1,6 @@
use anyhow::Context;
use hf_hub::api::sync::ApiBuilder;
use tracing::{debug, instrument};
use tracing::instrument;
use crate::{
configuration::{self},
@@ -17,12 +17,12 @@ use model::Model;
use super::TransformerBackend;
pub struct LlamaCPP {
pub struct LLaMACPP {
model: Model,
configuration: configuration::LLaMACPP,
}
impl LlamaCPP {
impl LLaMACPP {
#[instrument]
pub fn new(configuration: configuration::LLaMACPP) -> anyhow::Result<Self> {
let api = ApiBuilder::new().with_progress(true).build()?;
@@ -30,9 +30,9 @@ impl LlamaCPP {
.model
.name
.as_ref()
.context("Model `name` is required when using GGUF models")?;
.context("Please set `transformer->llamacpp->name` to use LLaMA.cpp")?;
let repo = api.model(configuration.model.repository.to_owned());
let model_path = repo.get(&name)?;
let model_path = repo.get(name)?;
let model = Model::new(model_path, &configuration.kwargs)?;
Ok(Self {
model,
@@ -50,7 +50,7 @@ impl LlamaCPP {
if let Some(chat_template) = &c.chat_template {
let bos_token = self.model.get_bos_token()?;
let eos_token = self.model.get_eos_token()?;
apply_chat_template(&chat_template, chat_messages, &bos_token, &eos_token)?
apply_chat_template(chat_template, chat_messages, &bos_token, &eos_token)?
} else {
self.model.apply_chat_template(chat_messages, None)?
}
@@ -64,12 +64,10 @@ impl LlamaCPP {
}
#[async_trait::async_trait]
impl TransformerBackend for LlamaCPP {
impl TransformerBackend for LLaMACPP {
#[instrument(skip(self))]
async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
// let prompt = self.get_prompt_string(prompt)?;
let prompt = &prompt.code;
debug!("Prompt string for LLM: {}", prompt);
let prompt = self.get_prompt_string(prompt)?;
let max_new_tokens = self.configuration.max_tokens.completion;
self.model
.complete(&prompt, max_new_tokens)
@@ -78,9 +76,7 @@ impl TransformerBackend for LlamaCPP {
#[instrument(skip(self))]
async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
// let prompt = self.get_prompt_string(prompt)?;
// debug!("Prompt string for LLM: {}", prompt);
let prompt = &prompt.code;
let prompt = self.get_prompt_string(prompt)?;
let max_new_tokens = self.configuration.max_tokens.completion;
self.model
.complete(&prompt, max_new_tokens)
@@ -96,60 +92,55 @@ impl TransformerBackend for LlamaCPP {
}
}
// #[cfg(test)]
// mod tests {
// use super::*;
// use serde_json::json;
#[cfg(test)]
mod test {
use super::*;
use serde_json::json;
// #[test]
// fn test_gguf() {
// let args = json!({
// "initializationOptions": {
// "memory": {
// "file_store": {}
// },
// "model_gguf": {
// "repository": "stabilityai/stable-code-3b",
// "name": "stable-code-3b-Q5_K_M.gguf",
// "max_new_tokens": {
// "completion": 32,
// "generation": 256,
// },
// // "fim": {
// // "start": "",
// // "middle": "",
// // "end": ""
// // },
// "chat": {
// "completion": [
// {
// "role": "system",
// "message": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief.\n\n{context}",
// },
// {
// "role": "user",
// "message": "Complete the following code: \n\n{code}"
// }
// ],
// "generation": [
// {
// "role": "system",
// "message": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
// },
// {
// "role": "user",
// "message": "Complete the following code: \n\n{code}"
// }
// ]
// },
// "n_ctx": 2048,
// "n_gpu_layers": 1000,
// }
// },
// });
// let configuration = Configuration::new(args).unwrap();
// let _model = LlamaCPP::new(configuration).unwrap();
// // let output = model.do_completion("def fibon").unwrap();
// // println!("{}", output.insert_text);
// }
// }
#[tokio::test]
async fn llama_cpp_do_completion() -> anyhow::Result<()> {
let configuration: configuration::LLaMACPP = serde_json::from_value(json!({
"repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
"name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
"fim": {
"start": "<fim_prefix>",
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"chat": {
// "completion": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
// "generation": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
"chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
},
"n_ctx": 2048,
"n_gpu_layers": 35,
}))?;
let llama_cpp = LLaMACPP::new(configuration).unwrap();
let prompt = Prompt::default_with_cursor();
let response = llama_cpp.do_completion(&prompt).await?;
assert!(!response.insert_text.is_empty());
Ok(())
}
}

View File

@@ -31,20 +31,23 @@ impl Model {
.unwrap_or_else(|| Ok(1000))?;
// Initialize the model_params
let model_params = {
#[cfg(feature = "cublas")]
if !params.disable_gpu {
LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers)
} else {
LlamaModelParams::default()
}
#[cfg(not(feature = "cublas"))]
LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers)
};
let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);
// Load the model
eprintln!();
eprintln!();
eprintln!();
eprintln!();
eprintln!();
eprintln!();
debug!("Loading model at path: {:?}", model_path);
let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
eprintln!();
eprintln!("LOADED THE MODEL");
eprintln!();
eprintln!();
eprintln!();
eprintln!();
// Get n_ctx if set in kwargs
// As a default we set it to 2048
@@ -65,7 +68,7 @@ impl Model {
pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
// initialize the context
let ctx_params = LlamaContextParams::default()
.with_n_ctx(Some(self.n_ctx.clone()))
.with_n_ctx(Some(self.n_ctx))
.with_n_batch(self.n_ctx.get());
let mut ctx = self

View File

@@ -26,7 +26,7 @@ impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send + Sync> {
fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
match configuration.config.transformer {
ValidTransformerBackend::LLaMACPP(model_gguf) => {
Ok(Box::new(llama_cpp::LlamaCPP::new(model_gguf)?))
Ok(Box::new(llama_cpp::LLaMACPP::new(model_gguf)?))
}
ValidTransformerBackend::OpenAI(openai_config) => {
Ok(Box::new(openai::OpenAI::new(openai_config)))