Cleaned so much stuff up add tracing add chat formatting

This commit is contained in:
Silas Marvin
2024-03-03 14:40:42 -08:00
parent 28b7b1b74e
commit 6627da705e
12 changed files with 353 additions and 225 deletions

115
Cargo.lock generated
View File

@@ -138,9 +138,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
[[package]]
name = "cc"
version = "1.0.86"
version = "1.0.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730"
checksum = "02f341c093d19155a6e41631ce5971aac4e9a868262212153124c15fa22d1cdc"
dependencies = [
"libc",
]
@@ -622,9 +622,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "llama-cpp-2"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8a5342c3eb45011e7e3646e22c5b8fcd3f25e049f0eb9618048e40b0027a59c"
version = "0.1.34"
dependencies = [
"llama-cpp-sys-2",
"thiserror",
@@ -633,9 +631,7 @@ dependencies = [
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1813a55afed6298991bcaaee040b49a83b473b3571ce37b4bbaa4b294ebcc36"
version = "0.1.34"
dependencies = [
"bindgen",
"cc",
@@ -675,6 +671,8 @@ dependencies = [
"serde",
"serde_json",
"tokenizers",
"tracing",
"tracing-subscriber",
]
[[package]]
@@ -691,9 +689,9 @@ dependencies = [
[[package]]
name = "lsp-types"
version = "0.94.1"
version = "0.95.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1"
checksum = "158c1911354ef73e8fe42da6b10c0484cb65c7f1007f28022e847706c1ab6984"
dependencies = [
"bitflags 1.3.2",
"serde",
@@ -718,6 +716,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568"
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "memchr"
version = "2.7.1"
@@ -797,6 +804,16 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "number_prefix"
version = "0.4.0"
@@ -881,6 +898,12 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking_lot"
version = "0.12.1"
@@ -1057,10 +1080,19 @@ checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-automata 0.4.5",
"regex-syntax 0.8.2",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]
[[package]]
name = "regex-automata"
version = "0.4.5"
@@ -1072,6 +1104,12 @@ dependencies = [
"regex-syntax 0.8.2",
]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.7.5"
@@ -1245,6 +1283,15 @@ dependencies = [
"syn 2.0.50",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "shlex"
version = "1.3.0"
@@ -1364,6 +1411,16 @@ dependencies = [
"syn 2.0.50",
]
[[package]]
name = "thread_local"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
@@ -1441,6 +1498,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]
[[package]]
@@ -1536,6 +1623,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "vcpkg"
version = "0.2.15"

View File

@@ -7,9 +7,9 @@ edition = "2021"
[dependencies]
anyhow = "1.0.75"
lsp-server = "0.7.4"
lsp-server = "0.7.6"
# lsp-server = { path = "../rust-analyzer/lib/lsp-server" }
lsp-types = "0.94.1"
lsp-types = "0.95.0"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
@@ -19,9 +19,11 @@ tokenizers = "0.14.1"
parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
llama-cpp-2 = "0.1.31"
# llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
# llama-cpp-2 = "0.1.31"
llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
minijinja = "1.0.12"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
[features]
default = []

View File

@@ -3,8 +3,6 @@ use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use crate::memory_backends::Prompt;
#[cfg(target_os = "macos")]
const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
@@ -26,7 +24,7 @@ pub enum ValidTransformerBackend {
#[derive(Debug, Clone, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub message: String,
pub content: String,
}
#[derive(Debug, Clone, Deserialize)]
@@ -37,14 +35,14 @@ pub struct Chat {
pub chat_format: Option<String>,
}
#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct FIM {
pub start: String,
pub middle: String,
pub end: String,
}
#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct MaxNewTokens {
pub completion: usize,
pub generation: usize,
@@ -59,7 +57,7 @@ impl Default for MaxNewTokens {
}
}
#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
struct ValidMemoryConfiguration {
file_store: Option<Value>,
}
@@ -72,13 +70,13 @@ impl Default for ValidMemoryConfiguration {
}
}
#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct Model {
pub repository: String,
pub name: Option<String>,
}
#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
struct ModelGGUF {
// The model to use
#[serde(flatten)]
@@ -114,7 +112,7 @@ impl Default for ModelGGUF {
}
}
#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
struct ValidMacTransformerConfiguration {
model_gguf: Option<ModelGGUF>,
}
@@ -127,7 +125,7 @@ impl Default for ValidMacTransformerConfiguration {
}
}
#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
struct ValidLinuxTransformerConfiguration {
model_gguf: Option<ModelGGUF>,
}
@@ -140,7 +138,7 @@ impl Default for ValidLinuxTransformerConfiguration {
}
}
#[derive(Clone, Deserialize, Default)]
#[derive(Clone, Debug, Deserialize, Default)]
struct ValidConfiguration {
memory: ValidMemoryConfiguration,
#[cfg(target_os = "macos")]
@@ -151,7 +149,7 @@ struct ValidConfiguration {
transformer: ValidLinuxTransformerConfiguration,
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Configuration {
valid_config: ValidConfiguration,
}
@@ -175,7 +173,7 @@ impl Configuration {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
Ok(&model_gguf.model)
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}
@@ -183,7 +181,7 @@ impl Configuration {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
Ok(&model_gguf.kwargs)
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}
@@ -203,9 +201,9 @@ impl Configuration {
}
}
pub fn get_maximum_context_length(&self) -> usize {
pub fn get_maximum_context_length(&self) -> Result<usize> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf
Ok(model_gguf
.kwargs
.get("n_ctx")
.map(|v| {
@@ -213,91 +211,33 @@ impl Configuration {
.map(|u| u as usize)
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
})
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX))
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}
pub fn get_max_new_tokens(&self) -> &MaxNewTokens {
pub fn get_max_new_tokens(&self) -> Result<&MaxNewTokens> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
&model_gguf.max_new_tokens
Ok(&model_gguf.max_new_tokens)
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}
pub fn get_fim(&self) -> Option<&FIM> {
pub fn get_fim(&self) -> Result<Option<&FIM>> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf.fim.as_ref()
Ok(model_gguf.fim.as_ref())
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}
pub fn get_chat(&self) -> Option<&Chat> {
pub fn get_chat(&self) -> Result<Option<&Chat>> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf.chat.as_ref()
Ok(model_gguf.chat.as_ref())
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn custom_mac_gguf_model() {
let args = json!({
"initializationOptions": {
"memory": {
"file_store": {}
},
"macos": {
"model_gguf": {
// "repository": "deepseek-coder-6.7b-base",
// "name": "Q4_K_M.gguf",
"repository": "stabilityai/stablelm-2-zephyr-1_6b",
"name": "stablelm-2-zephyr-1_6b-Q5_K_M.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
"fim": {
"start": "<fim_prefix>",
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"chat": {
"completion": [
{
"role": "system",
"message": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief.\n\n{context}",
},
{
"role": "user",
"message": "Complete the following code: \n\n{code}"
}
],
"generation": [
{
"role": "system",
"message": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
},
{
"role": "user",
"message": "Complete the following code: \n\n{code}"
}
]
},
"n_ctx": 2048,
"n_gpu_layers": 35,
}
},
}
});
let _ = Configuration::new(args).unwrap();
}
}

View File

@@ -7,12 +7,12 @@ use lsp_types::{
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use tracing::error;
use tracing_subscriber::{EnvFilter, FmtSubscriber};
mod configuration;
mod custom_requests;
mod memory_backends;
mod template;
mod tokenizer;
mod transformer_backends;
mod utils;
mod worker;
@@ -43,6 +43,13 @@ where
}
fn main() -> Result<()> {
// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
// If the variables value is malformed or missing, sets the default log level to ERROR
FmtSubscriber::builder()
.with_writer(std::io::stderr)
.with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
.init();
let (connection, io_threads) = Connection::stdio();
let server_capabilities = serde_json::to_value(&ServerCapabilities {
completion_provider: Some(CompletionOptions::default()),
@@ -104,12 +111,11 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
if request_is::<Completion>(&req) {
match cast::<Completion>(req) {
Ok((id, params)) => {
eprintln!("******{:?}********", id);
let mut lcr = last_worker_request.lock();
let completion_request = CompletionRequest::new(id, params);
*lcr = Some(WorkerRequest::Completion(completion_request));
}
Err(err) => eprintln!("{err:?}"),
Err(err) => error!("{err:?}"),
}
} else if request_is::<Generate>(&req) {
match cast::<Generate>(req) {
@@ -118,7 +124,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let completion_request = GenerateRequest::new(id, params);
*lcr = Some(WorkerRequest::Generate(completion_request));
}
Err(err) => eprintln!("{err:?}"),
Err(err) => error!("{err:?}"),
}
} else if request_is::<GenerateStream>(&req) {
match cast::<GenerateStream>(req) {
@@ -127,10 +133,10 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let completion_request = GenerateStreamRequest::new(id, params);
*lcr = Some(WorkerRequest::GenerateStream(completion_request));
}
Err(err) => eprintln!("{err:?}"),
Err(err) => error!("{err:?}"),
}
} else {
eprintln!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream")
error!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream")
}
}
Message::Notification(not) => {
@@ -150,3 +156,69 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
}
Ok(())
}
#[cfg(test)]
mod tests {
use crate::memory_backends::Prompt;
use super::*;
use serde_json::json;
#[test]
fn custom_mac_gguf_model() {
let args = json!({
"initializationOptions": {
"memory": {
"file_store": {}
},
"macos": {
"model_gguf": {
"repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
"name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
// "repository": "stabilityai/stablelm-2-zephyr-1_6b",
// "name": "stablelm-2-zephyr-1_6b-Q5_K_M.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
// "fim": {
// "start": "<fim_prefix>",
// "middle": "<fim_suffix>",
// "end": "<fim_middle>"
// },
// "chat": {
// "completion": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
// "generation": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
// // "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
// },
"n_ctx": 2048,
"n_gpu_layers": 35,
}
},
}
});
let configuration = Configuration::new(args).unwrap();
let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
let prompt = Prompt::new("".to_string(), "def fibn".to_string());
let response = backend.do_completion(&prompt).unwrap();
eprintln!("\nRESPONSE:\n{:?}", response.insert_text);
}
}

View File

@@ -2,10 +2,11 @@ use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use ropey::Rope;
use std::collections::HashMap;
use tracing::instrument;
use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};
use super::{MemoryBackend, Prompt};
use super::{MemoryBackend, Prompt, PromptForType};
pub struct FileStore {
configuration: Configuration,
@@ -22,6 +23,7 @@ impl FileStore {
}
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
let rope = self
.file_map
@@ -34,7 +36,12 @@ impl MemoryBackend for FileStore {
.to_string())
}
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt> {
#[instrument(skip(self))]
fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
let mut rope = self
.file_map
.get(position.text_document.uri.as_str())
@@ -44,15 +51,41 @@ impl MemoryBackend for FileStore {
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
// We only want to do FIM if the user has enabled it, and the cursor is not at the end of the file
let code = match self.configuration.get_fim() {
Some(fim) if rope.len_chars() != cursor_index => {
let max_length =
characters_to_estimated_tokens(self.configuration.get_maximum_context_length());
let is_chat_enabled = match prompt_for_type {
PromptForType::Completion => self
.configuration
.get_chat()?
.map(|c| c.completion.is_some())
.unwrap_or(false),
PromptForType::Generate => self
.configuration
.get_chat()?
.map(|c| c.generation.is_some())
.unwrap_or(false),
};
// We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
// and the user has not enabled chat
let code = match (is_chat_enabled, self.configuration.get_fim()?) {
r @ (true, _) | r @ (false, Some(_))
if is_chat_enabled || rope.len_chars() != cursor_index =>
{
let max_length = characters_to_estimated_tokens(
self.configuration.get_maximum_context_length()?,
);
let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
if is_chat_enabled {
rope.insert(cursor_index, "{CURSOR}");
let rope_slice = rope
.get_slice(start..end + "{CURSOR}".chars().count())
.context("Error getting rope slice")?;
rope_slice.to_string()
} else {
let fim = r.1.unwrap(); // We can unwrap as we know it is some from the match
rope.insert(end, &fim.end);
rope.insert(cursor_index, &fim.middle);
rope.insert(start, &fim.start);
@@ -67,10 +100,11 @@ impl MemoryBackend for FileStore {
.context("Error getting rope slice")?;
rope_slice.to_string()
}
}
_ => {
let start = cursor_index
.checked_sub(characters_to_estimated_tokens(
self.configuration.get_maximum_context_length(),
self.configuration.get_maximum_context_length()?,
))
.unwrap_or(0);
let rope_slice = rope
@@ -82,6 +116,7 @@ impl MemoryBackend for FileStore {
Ok(Prompt::new("".to_string(), code))
}
#[instrument(skip(self))]
fn opened_text_document(
&mut self,
params: lsp_types::DidOpenTextDocumentParams,
@@ -92,6 +127,7 @@ impl MemoryBackend for FileStore {
Ok(())
}
#[instrument(skip(self))]
fn changed_text_document(
&mut self,
params: lsp_types::DidChangeTextDocumentParams,
@@ -116,6 +152,7 @@ impl MemoryBackend for FileStore {
Ok(())
}
#[instrument(skip(self))]
fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
for file_rename in params.files {
if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {

View File

@@ -14,11 +14,17 @@ pub struct Prompt {
}
impl Prompt {
fn new(context: String, code: String) -> Self {
pub fn new(context: String, code: String) -> Self {
Self { context, code }
}
}
#[derive(Debug)]
pub enum PromptForType {
Completion,
Generate,
}
pub trait MemoryBackend {
fn init(&self) -> anyhow::Result<()> {
Ok(())
@@ -26,7 +32,11 @@ pub trait MemoryBackend {
fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt>;
fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt>;
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
}

View File

@@ -1,39 +0,0 @@
use crate::{
configuration::{Chat, ChatMessage, Configuration},
tokenizer::Tokenizer,
};
use hf_hub::api::sync::{Api, ApiRepo};
// // Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
// const CHATML_CHAT_TEMPLATE: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
// const CHATML_BOS_TOKEN: &str = "<s>";
// const CHATML_EOS_TOKEN: &str = "<|im_end|>";
// // Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
// const MISTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
// const MISTRAL_INSTRUCT_BOS_TOKEN: &str = "<s>";
// const MISTRAL_INSTRUCT_EOS_TOKEN: &str = "</s>";
// // Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
// const MIXTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
pub struct Template {
configuration: Configuration,
}
// impl Template {
// pub fn new(configuration: Configuration) -> Self {
// Self { configuration }
// }
// }
pub fn apply_prompt(
chat_messages: Vec<ChatMessage>,
chat: &Chat,
tokenizer: Option<&Tokenizer>,
) -> anyhow::Result<String> {
// If we have the chat template apply it
// If we have the chat_format see if we have it set
// If we don't have the chat_format set here, try and get the chat_template from the tokenizer_config.json file
anyhow::bail!("Please set chat_template or chat_format. Could not find the information in the tokenizer_config.json file")
}

View File

@@ -1,7 +0,0 @@
pub struct Tokenizer {}
impl Tokenizer {
pub fn maybe_from_repo(repo: ApiRepo) -> anyhow::Result<Option<Self>> {
unimplemented!()
}
}

View File

@@ -1,12 +1,11 @@
use anyhow::Context;
use hf_hub::api::sync::Api;
use tracing::{debug, instrument};
use super::TransformerBackend;
use crate::{
configuration::{Chat, Configuration},
configuration::Configuration,
memory_backends::Prompt,
template::{apply_prompt, Template},
tokenizer::Tokenizer,
utils::format_chat_messages,
worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
@@ -19,10 +18,10 @@ use model::Model;
pub struct LlamaCPP {
model: Model,
configuration: Configuration,
tokenizer: Option<Tokenizer>,
}
impl LlamaCPP {
#[instrument]
pub fn new(configuration: Configuration) -> anyhow::Result<Self> {
let api = Api::new()?;
let model = configuration.get_model()?;
@@ -32,44 +31,53 @@ impl LlamaCPP {
.context("Model `name` is required when using GGUF models")?;
let repo = api.model(model.repository.to_owned());
let model_path = repo.get(&name)?;
let tokenizer: Option<Tokenizer> = Tokenizer::maybe_from_repo(repo)?;
let model = Model::new(model_path, configuration.get_model_kwargs()?)?;
Ok(Self {
model,
configuration,
tokenizer,
})
}
}
impl TransformerBackend for LlamaCPP {
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
#[instrument(skip(self))]
fn get_prompt_string(&self, prompt: &Prompt) -> anyhow::Result<String> {
// We need to check that they not only set the `chat` key, but they set the `completion` sub key
let prompt = match self.configuration.get_chat() {
Ok(match self.configuration.get_chat()? {
Some(c) => {
if let Some(completion_messages) = &c.completion {
let chat_messages = format_chat_messages(completion_messages, prompt);
apply_prompt(chat_messages, c, self.tokenizer.as_ref())?
self.model
.apply_chat_template(chat_messages, c.chat_template.to_owned())?
} else {
prompt.code.to_owned()
}
}
None => prompt.code.to_owned(),
};
let max_new_tokens = self.configuration.get_max_new_tokens().completion;
})
}
}
impl TransformerBackend for LlamaCPP {
#[instrument(skip(self))]
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
let prompt = self.get_prompt_string(prompt)?;
// debug!("Prompt string for LLM: {}", prompt);
let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
self.model
.complete(&prompt, max_new_tokens)
.map(|insert_text| DoCompletionResponse { insert_text })
}
#[instrument(skip(self))]
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
unimplemented!()
// let max_new_tokens = self.configuration.get_max_new_tokens().generation;
// self.model
// .complete(prompt, max_new_tokens)
// .map(|generated_text| DoGenerateResponse { generated_text })
let prompt = self.get_prompt_string(prompt)?;
// debug!("Prompt string for LLM: {}", prompt);
let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
self.model
.complete(&prompt, max_new_tokens)
.map(|generated_text| DoGenerateResponse { generated_text })
}
#[instrument(skip(self))]
fn do_generate_stream(
&self,
_request: &GenerateStreamRequest,
@@ -132,7 +140,7 @@ mod tests {
}
});
let configuration = Configuration::new(args).unwrap();
let model = LlamaCPP::new(configuration).unwrap();
let _model = LlamaCPP::new(configuration).unwrap();
// let output = model.do_completion("def fibon").unwrap();
// println!("{}", output.insert_text);
}

View File

@@ -4,13 +4,14 @@ use llama_cpp_2::{
ggml_time_us,
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{params::LlamaModelParams, AddBos, LlamaModel},
model::{params::LlamaModelParams, AddBos, LlamaChatMessage, LlamaModel},
token::data_array::LlamaTokenDataArray,
};
use once_cell::sync::Lazy;
use std::{num::NonZeroU32, path::PathBuf, time::Duration};
use tracing::{debug, info, instrument};
use crate::configuration::Kwargs;
use crate::configuration::{ChatMessage, Kwargs};
static BACKEND: Lazy<LlamaBackend> = Lazy::new(|| LlamaBackend::init().unwrap());
@@ -20,6 +21,7 @@ pub struct Model {
}
impl Model {
#[instrument]
pub fn new(model_path: PathBuf, kwargs: &Kwargs) -> anyhow::Result<Self> {
// Get n_gpu_layers if set in kwargs
// As a default we set it to 1000, which should put all layers on the GPU
@@ -41,9 +43,8 @@ impl Model {
};
// Load the model
eprintln!("SETTING MODEL AT PATH: {:?}", model_path);
debug!("Loading model at path: {:?}", model_path);
let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
eprintln!("\nMODEL SET\n");
// Get n_ctx if set in kwargs
// As a default we set it to 2048
@@ -60,6 +61,7 @@ impl Model {
Ok(Model { model, n_ctx })
}
#[instrument(skip(self))]
pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
// initialize the context
let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone()));
@@ -77,9 +79,7 @@ impl Model {
let n_cxt = ctx.n_ctx() as usize;
let n_kv_req = tokens_list.len() + max_new_tokens;
eprintln!(
"n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}"
);
info!("n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}");
// make sure the KV cache is big enough to hold all the prompt and generated tokens
if n_kv_req > n_cxt {
@@ -132,14 +132,29 @@ impl Model {
let t_main_end = ggml_time_us();
let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
eprintln!(
info!(
"decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
n_decode,
duration.as_secs_f32(),
n_decode as f32 / duration.as_secs_f32()
);
eprintln!("{}", ctx.timings());
info!("{}", ctx.timings());
Ok(output.join(""))
}
#[instrument(skip(self))]
pub fn apply_chat_template(
&self,
messages: Vec<ChatMessage>,
template: Option<String>,
) -> anyhow::Result<String> {
let llama_chat_messages = messages
.into_iter()
.map(|c| LlamaChatMessage::new(c.role, c.content))
.collect::<Result<Vec<LlamaChatMessage>, _>>()?;
Ok(self
.model
.apply_chat_template(template, llama_chat_messages, true)?)
}
}

View File

@@ -9,7 +9,7 @@ pub trait ToResponseError {
impl ToResponseError for anyhow::Error {
fn to_response_error(&self, code: i32) -> ResponseError {
ResponseError {
code: -32603,
code,
message: self.to_string(),
data: None,
}
@@ -25,8 +25,8 @@ pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec
.iter()
.map(|m| ChatMessage {
role: m.role.to_owned(),
message: m
.message
content: m
.content
.replace("{context}", &prompt.context)
.replace("{code}", &prompt.code),
})

View File

@@ -5,14 +5,15 @@ use lsp_types::{
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use tracing::instrument;
use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::GenerateStreamParams;
use crate::memory_backends::MemoryBackend;
use crate::memory_backends::{MemoryBackend, PromptForType};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct CompletionRequest {
id: RequestId,
params: CompletionParams,
@@ -24,7 +25,7 @@ impl CompletionRequest {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct GenerateRequest {
id: RequestId,
params: GenerateParams,
@@ -38,7 +39,7 @@ impl GenerateRequest {
// The generate stream is not yet ready but we don't want to remove it
#[allow(dead_code)]
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct GenerateStreamRequest {
id: RequestId,
params: GenerateStreamParams,
@@ -91,21 +92,17 @@ impl Worker {
}
}
#[instrument(skip(self))]
fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
let prompt = self
.memory_backend
.lock()
.build_prompt(&request.params.text_document_position)?;
let prompt = self.memory_backend.lock().build_prompt(
&request.params.text_document_position,
PromptForType::Completion,
)?;
let filter_text = self
.memory_backend
.lock()
.get_filter_text(&request.params.text_document_position)?;
eprintln!("\nPROMPT**************\n{:?}\n******************\n", prompt);
let response = self.transformer_backend.do_completion(&prompt)?;
eprintln!(
"\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{:?}\n&&&&&&&&&&&&&&&&&&\n",
response.insert_text
);
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
@@ -139,12 +136,12 @@ impl Worker {
})
}
#[instrument(skip(self))]
fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
let prompt = self
.memory_backend
.lock()
.build_prompt(&request.params.text_document_position)?;
eprintln!("\nPROMPT*************\n{:?}\n************\n", prompt);
let prompt = self.memory_backend.lock().build_prompt(
&request.params.text_document_position,
PromptForType::Generate,
)?;
let response = self.transformer_backend.do_generate(&prompt)?;
let result = GenerateResult {
generated_text: response.generated_text,