Mirror of https://github.com/SilasMarvin/lsp-ai.git
Cleaned so much stuff up add tracing add chat formatting
Cargo.lock (generated, 115 changes)
@@ -138,9 +138,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
 
 [[package]]
 name = "cc"
-version = "1.0.86"
+version = "1.0.88"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730"
+checksum = "02f341c093d19155a6e41631ce5971aac4e9a868262212153124c15fa22d1cdc"
 dependencies = [
  "libc",
 ]
@@ -622,9 +622,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
 
 [[package]]
 name = "llama-cpp-2"
-version = "0.1.31"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8a5342c3eb45011e7e3646e22c5b8fcd3f25e049f0eb9618048e40b0027a59c"
+version = "0.1.34"
 dependencies = [
  "llama-cpp-sys-2",
  "thiserror",
@@ -633,9 +631,7 @@ dependencies = [
 
 [[package]]
 name = "llama-cpp-sys-2"
-version = "0.1.31"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e1813a55afed6298991bcaaee040b49a83b473b3571ce37b4bbaa4b294ebcc36"
+version = "0.1.34"
 dependencies = [
  "bindgen",
  "cc",
@@ -675,6 +671,8 @@ dependencies = [
  "serde",
  "serde_json",
  "tokenizers",
+ "tracing",
+ "tracing-subscriber",
 ]
 
 [[package]]
@@ -691,9 +689,9 @@ dependencies = [
 
 [[package]]
 name = "lsp-types"
-version = "0.94.1"
+version = "0.95.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1"
+checksum = "158c1911354ef73e8fe42da6b10c0484cb65c7f1007f28022e847706c1ab6984"
 dependencies = [
  "bitflags 1.3.2",
  "serde",
@@ -718,6 +716,15 @@ version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568"
 
+[[package]]
+name = "matchers"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
+dependencies = [
+ "regex-automata 0.1.10",
+]
+
 [[package]]
 name = "memchr"
 version = "2.7.1"
@@ -797,6 +804,16 @@ dependencies = [
  "minimal-lexical",
 ]
 
+[[package]]
+name = "nu-ansi-term"
+version = "0.46.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
+dependencies = [
+ "overload",
+ "winapi",
+]
+
 [[package]]
 name = "number_prefix"
 version = "0.4.0"
@@ -881,6 +898,12 @@ version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
 
+[[package]]
+name = "overload"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
+
 [[package]]
 name = "parking_lot"
 version = "0.12.1"
@@ -1057,10 +1080,19 @@ checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15"
 dependencies = [
  "aho-corasick",
  "memchr",
- "regex-automata",
+ "regex-automata 0.4.5",
  "regex-syntax 0.8.2",
 ]
 
+[[package]]
+name = "regex-automata"
+version = "0.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
+dependencies = [
+ "regex-syntax 0.6.29",
+]
+
 [[package]]
 name = "regex-automata"
 version = "0.4.5"
@@ -1072,6 +1104,12 @@ dependencies = [
  "regex-syntax 0.8.2",
 ]
 
+[[package]]
+name = "regex-syntax"
+version = "0.6.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
+
 [[package]]
 name = "regex-syntax"
 version = "0.7.5"
@@ -1245,6 +1283,15 @@ dependencies = [
  "syn 2.0.50",
 ]
 
+[[package]]
+name = "sharded-slab"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
+dependencies = [
+ "lazy_static",
+]
+
 [[package]]
 name = "shlex"
 version = "1.3.0"
@@ -1364,6 +1411,16 @@ dependencies = [
  "syn 2.0.50",
 ]
 
+[[package]]
+name = "thread_local"
+version = "1.1.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
+dependencies = [
+ "cfg-if",
+ "once_cell",
+]
+
 [[package]]
 name = "tinyvec"
 version = "1.6.0"
@@ -1441,6 +1498,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
 dependencies = [
  "once_cell",
+ "valuable",
+]
+
+[[package]]
+name = "tracing-log"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
+dependencies = [
+ "log",
+ "once_cell",
+ "tracing-core",
+]
+
+[[package]]
+name = "tracing-subscriber"
+version = "0.3.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
+dependencies = [
+ "matchers",
+ "nu-ansi-term",
+ "once_cell",
+ "regex",
+ "sharded-slab",
+ "smallvec",
+ "thread_local",
+ "tracing",
+ "tracing-core",
+ "tracing-log",
 ]
 
 [[package]]
@@ -1536,6 +1623,12 @@ version = "0.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
 
+[[package]]
+name = "valuable"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
+
 [[package]]
 name = "vcpkg"
 version = "0.2.15"
Cargo.toml (10 changes)
@@ -7,9 +7,9 @@ edition = "2021"
 
 [dependencies]
 anyhow = "1.0.75"
-lsp-server = "0.7.4"
+lsp-server = "0.7.6"
 # lsp-server = { path = "../rust-analyzer/lib/lsp-server" }
-lsp-types = "0.94.1"
+lsp-types = "0.95.0"
 ropey = "1.6.1"
 serde = "1.0.190"
 serde_json = "1.0.108"
@@ -19,9 +19,11 @@ tokenizers = "0.14.1"
 parking_lot = "0.12.1"
 once_cell = "1.19.0"
 directories = "5.0.1"
-llama-cpp-2 = "0.1.31"
-# llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
+# llama-cpp-2 = "0.1.31"
+llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
 minijinja = "1.0.12"
+tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
+tracing = "0.1.40"
 
 [features]
 default = []
@@ -3,8 +3,6 @@ use serde::Deserialize;
 use serde_json::{json, Value};
 use std::collections::HashMap;
 
-use crate::memory_backends::Prompt;
-
 #[cfg(target_os = "macos")]
 const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
 
@@ -26,7 +24,7 @@ pub enum ValidTransformerBackend
 #[derive(Debug, Clone, Deserialize)]
 pub struct ChatMessage {
     pub role: String,
-    pub message: String,
+    pub content: String,
 }
 
 #[derive(Debug, Clone, Deserialize)]
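Aside (not part of the commit): a minimal standalone sketch of how a configured message deserializes into the renamed struct, assuming only the `ChatMessage` shape shown in the hunk above; the `main` wrapper and the literal JSON are illustrative.

use serde::Deserialize;

#[derive(Debug, Clone, Deserialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

fn main() -> Result<(), serde_json::Error> {
    // A config entry that previously used "message" now uses "content".
    let raw = r#"{ "role": "user", "content": "Complete the following code: \n\n{code}" }"#;
    let message: ChatMessage = serde_json::from_str(raw)?;
    assert_eq!(message.role, "user");
    println!("{message:?}");
    Ok(())
}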
@@ -37,14 +35,14 @@ pub struct Chat {
     pub chat_format: Option<String>,
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub struct FIM {
     pub start: String,
     pub middle: String,
     pub end: String,
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub struct MaxNewTokens {
     pub completion: usize,
     pub generation: usize,
@@ -59,7 +57,7 @@ impl Default for MaxNewTokens {
     }
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 struct ValidMemoryConfiguration {
     file_store: Option<Value>,
 }
@@ -72,13 +70,13 @@ impl Default for ValidMemoryConfiguration {
     }
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub struct Model {
     pub repository: String,
     pub name: Option<String>,
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 struct ModelGGUF {
     // The model to use
     #[serde(flatten)]
@@ -114,7 +112,7 @@ impl Default for ModelGGUF {
     }
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 struct ValidMacTransformerConfiguration {
     model_gguf: Option<ModelGGUF>,
 }
@@ -127,7 +125,7 @@ impl Default for ValidMacTransformerConfiguration {
     }
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 struct ValidLinuxTransformerConfiguration {
     model_gguf: Option<ModelGGUF>,
 }
@@ -140,7 +138,7 @@ impl Default for ValidLinuxTransformerConfiguration {
     }
 }
 
-#[derive(Clone, Deserialize, Default)]
+#[derive(Clone, Debug, Deserialize, Default)]
 struct ValidConfiguration {
     memory: ValidMemoryConfiguration,
     #[cfg(target_os = "macos")]
@@ -151,7 +149,7 @@ struct ValidConfiguration {
     transformer: ValidLinuxTransformerConfiguration,
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct Configuration {
     valid_config: ValidConfiguration,
 }
@@ -175,7 +173,7 @@ impl Configuration {
         if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
             Ok(&model_gguf.model)
         } else {
-            panic!("We currently only support gguf models using llama cpp")
+            anyhow::bail!("We currently only support gguf models using llama cpp")
         }
     }
 
@@ -183,7 +181,7 @@ impl Configuration {
         if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
             Ok(&model_gguf.kwargs)
         } else {
-            panic!("We currently only support gguf models using llama cpp")
+            anyhow::bail!("We currently only support gguf models using llama cpp")
         }
     }
 
@@ -203,9 +201,9 @@ impl Configuration {
         }
     }
 
-    pub fn get_maximum_context_length(&self) -> usize {
+    pub fn get_maximum_context_length(&self) -> Result<usize> {
         if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
-            model_gguf
+            Ok(model_gguf
                 .kwargs
                 .get("n_ctx")
                 .map(|v| {
@@ -213,91 +211,33 @@ impl Configuration {
                         .map(|u| u as usize)
                         .unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
                 })
-                .unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
+                .unwrap_or(DEFAULT_LLAMA_CPP_N_CTX))
         } else {
-            panic!("We currently only support gguf models using llama cpp")
+            anyhow::bail!("We currently only support gguf models using llama cpp")
         }
     }
 
-    pub fn get_max_new_tokens(&self) -> &MaxNewTokens {
+    pub fn get_max_new_tokens(&self) -> Result<&MaxNewTokens> {
         if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
-            &model_gguf.max_new_tokens
+            Ok(&model_gguf.max_new_tokens)
         } else {
-            panic!("We currently only support gguf models using llama cpp")
+            anyhow::bail!("We currently only support gguf models using llama cpp")
         }
     }
 
-    pub fn get_fim(&self) -> Option<&FIM> {
+    pub fn get_fim(&self) -> Result<Option<&FIM>> {
         if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
-            model_gguf.fim.as_ref()
+            Ok(model_gguf.fim.as_ref())
         } else {
-            panic!("We currently only support gguf models using llama cpp")
+            anyhow::bail!("We currently only support gguf models using llama cpp")
        }
     }
 
-    pub fn get_chat(&self) -> Option<&Chat> {
+    pub fn get_chat(&self) -> Result<Option<&Chat>> {
         if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
-            model_gguf.chat.as_ref()
+            Ok(model_gguf.chat.as_ref())
         } else {
-            panic!("We currently only support gguf models using llama cpp")
+            anyhow::bail!("We currently only support gguf models using llama cpp")
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn custom_mac_gguf_model() {
-        let args = json!({
-            "initializationOptions": {
-                "memory": {
-                    "file_store": {}
-                },
-                "macos": {
-                    "model_gguf": {
-                        // "repository": "deepseek-coder-6.7b-base",
-                        // "name": "Q4_K_M.gguf",
-                        "repository": "stabilityai/stablelm-2-zephyr-1_6b",
-                        "name": "stablelm-2-zephyr-1_6b-Q5_K_M.gguf",
-                        "max_new_tokens": {
-                            "completion": 32,
-                            "generation": 256,
-                        },
-                        "fim": {
-                            "start": "<fim_prefix>",
-                            "middle": "<fim_suffix>",
-                            "end": "<fim_middle>"
-                        },
-                        "chat": {
-                            "completion": [
-                                {
-                                    "role": "system",
-                                    "message": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief.\n\n{context}",
-                                },
-                                {
-                                    "role": "user",
-                                    "message": "Complete the following code: \n\n{code}"
-                                }
-                            ],
-                            "generation": [
-                                {
-                                    "role": "system",
-                                    "message": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
-                                },
-                                {
-                                    "role": "user",
-                                    "message": "Complete the following code: \n\n{code}"
-                                }
-                            ]
-                        },
-                        "n_ctx": 2048,
-                        "n_gpu_layers": 35,
-                    }
-                },
-            }
-        });
-        let _ = Configuration::new(args).unwrap();
-    }
-}
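Aside (not part of the commit): the panic! to anyhow::bail! change means an unsupported configuration now surfaces as an Err the caller can report instead of aborting the server. A small illustrative sketch of that pattern, using a made-up getter:

use anyhow::Result;

fn get_model_name(model_gguf: Option<&str>) -> Result<&str> {
    if let Some(name) = model_gguf {
        Ok(name)
    } else {
        // Propagates as an error instead of panicking the whole process
        anyhow::bail!("We currently only support gguf models using llama cpp")
    }
}

fn main() {
    match get_model_name(None) {
        Ok(name) => println!("model: {name}"),
        Err(e) => eprintln!("configuration error: {e}"),
    }
}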
src/main.rs (86 changes)
@@ -7,12 +7,12 @@ use lsp_types::{
 };
 use parking_lot::Mutex;
 use std::{sync::Arc, thread};
+use tracing::error;
+use tracing_subscriber::{EnvFilter, FmtSubscriber};
 
 mod configuration;
 mod custom_requests;
 mod memory_backends;
-mod template;
-mod tokenizer;
 mod transformer_backends;
 mod utils;
 mod worker;
@@ -43,6 +43,13 @@ where
 }
 
 fn main() -> Result<()> {
+    // Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
+    // If the variables value is malformed or missing, sets the default log level to ERROR
+    FmtSubscriber::builder()
+        .with_writer(std::io::stderr)
+        .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
+        .init();
+
     let (connection, io_threads) = Connection::stdio();
     let server_capabilities = serde_json::to_value(&ServerCapabilities {
         completion_provider: Some(CompletionOptions::default()),
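Aside (not part of the commit): a standalone sketch of the same LSP_AI_LOG-driven setup the new main() performs, useful for seeing which levels the filter lets through; it assumes the crates added to Cargo.toml above (tracing 0.1 and tracing-subscriber 0.3 with the env-filter feature).

use tracing::{debug, error, info};
use tracing_subscriber::{EnvFilter, FmtSubscriber};

fn main() {
    // Run with e.g. LSP_AI_LOG=debug to see everything below
    FmtSubscriber::builder()
        .with_writer(std::io::stderr)
        .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
        .init();

    error!("emitted at the default filter level");
    info!("emitted when LSP_AI_LOG allows info or lower");
    debug!("emitted when LSP_AI_LOG allows debug or trace");
}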
@@ -104,12 +111,11 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
                 if request_is::<Completion>(&req) {
                     match cast::<Completion>(req) {
                         Ok((id, params)) => {
-                            eprintln!("******{:?}********", id);
                             let mut lcr = last_worker_request.lock();
                             let completion_request = CompletionRequest::new(id, params);
                             *lcr = Some(WorkerRequest::Completion(completion_request));
                         }
-                        Err(err) => eprintln!("{err:?}"),
+                        Err(err) => error!("{err:?}"),
                     }
                 } else if request_is::<Generate>(&req) {
                     match cast::<Generate>(req) {
@@ -118,7 +124,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
                             let completion_request = GenerateRequest::new(id, params);
                             *lcr = Some(WorkerRequest::Generate(completion_request));
                         }
-                        Err(err) => eprintln!("{err:?}"),
+                        Err(err) => error!("{err:?}"),
                     }
                 } else if request_is::<GenerateStream>(&req) {
                     match cast::<GenerateStream>(req) {
@@ -127,10 +133,10 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
                             let completion_request = GenerateStreamRequest::new(id, params);
                             *lcr = Some(WorkerRequest::GenerateStream(completion_request));
                         }
-                        Err(err) => eprintln!("{err:?}"),
+                        Err(err) => error!("{err:?}"),
                     }
                 } else {
-                    eprintln!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream")
+                    error!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream")
                 }
             }
             Message::Notification(not) => {
@@ -150,3 +156,69 @@
     }
     Ok(())
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::memory_backends::Prompt;
+
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn custom_mac_gguf_model() {
+        let args = json!({
+            "initializationOptions": {
+                "memory": {
+                    "file_store": {}
+                },
+                "macos": {
+                    "model_gguf": {
+                        "repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
+                        "name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
+                        // "repository": "stabilityai/stablelm-2-zephyr-1_6b",
+                        // "name": "stablelm-2-zephyr-1_6b-Q5_K_M.gguf",
+                        "max_new_tokens": {
+                            "completion": 32,
+                            "generation": 256,
+                        },
+                        // "fim": {
+                        //     "start": "<fim_prefix>",
+                        //     "middle": "<fim_suffix>",
+                        //     "end": "<fim_middle>"
+                        // },
+                        // "chat": {
+                        //     "completion": [
+                        //         {
+                        //             "role": "system",
+                        //             "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
+                        //         },
+                        //         {
+                        //             "role": "user",
+                        //             "content": "Complete the following code: \n\n{code}"
+                        //         }
+                        //     ],
+                        //     "generation": [
+                        //         {
+                        //             "role": "system",
+                        //             "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
+                        //         },
+                        //         {
+                        //             "role": "user",
+                        //             "content": "Complete the following code: \n\n{code}"
+                        //         }
+                        //     ],
+                        //     // "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set ns.found = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
+                        // },
+                        "n_ctx": 2048,
+                        "n_gpu_layers": 35,
+                    }
+                },
+            }
+        });
+        let configuration = Configuration::new(args).unwrap();
+        let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
+        let prompt = Prompt::new("".to_string(), "def fibn".to_string());
+        let response = backend.do_completion(&prompt).unwrap();
+        eprintln!("\nRESPONSE:\n{:?}", response.insert_text);
+    }
+}
@@ -2,10 +2,11 @@ use anyhow::Context;
 use lsp_types::TextDocumentPositionParams;
 use ropey::Rope;
 use std::collections::HashMap;
+use tracing::instrument;
 
 use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};
 
-use super::{MemoryBackend, Prompt};
+use super::{MemoryBackend, Prompt, PromptForType};
 
 pub struct FileStore {
     configuration: Configuration,
@@ -22,6 +23,7 @@ impl FileStore {
 }
 
 impl MemoryBackend for FileStore {
+    #[instrument(skip(self))]
     fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
         let rope = self
             .file_map
@@ -34,7 +36,12 @@ impl MemoryBackend for FileStore {
             .to_string())
     }
 
-    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt> {
+    #[instrument(skip(self))]
+    fn build_prompt(
+        &self,
+        position: &TextDocumentPositionParams,
+        prompt_for_type: PromptForType,
+    ) -> anyhow::Result<Prompt> {
         let mut rope = self
             .file_map
             .get(position.text_document.uri.as_str())
@@ -44,33 +51,60 @@ impl MemoryBackend for FileStore {
         let cursor_index = rope.line_to_char(position.position.line as usize)
             + position.position.character as usize;
 
-        // We only want to do FIM if the user has enabled it, and the cursor is not at the end of the file
-        let code = match self.configuration.get_fim() {
-            Some(fim) if rope.len_chars() != cursor_index => {
-                let max_length =
-                    characters_to_estimated_tokens(self.configuration.get_maximum_context_length());
+        let is_chat_enabled = match prompt_for_type {
+            PromptForType::Completion => self
+                .configuration
+                .get_chat()?
+                .map(|c| c.completion.is_some())
+                .unwrap_or(false),
+            PromptForType::Generate => self
+                .configuration
+                .get_chat()?
+                .map(|c| c.generation.is_some())
+                .unwrap_or(false),
+        };
+
+        // We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
+        // and the user has not enabled chat
+        let code = match (is_chat_enabled, self.configuration.get_fim()?) {
+            r @ (true, _) | r @ (false, Some(_))
+                if is_chat_enabled || rope.len_chars() != cursor_index =>
+            {
+                let max_length = characters_to_estimated_tokens(
+                    self.configuration.get_maximum_context_length()?,
+                );
                 let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
                 let end = rope
                     .len_chars()
                     .min(cursor_index + (max_length - (cursor_index - start)));
-                rope.insert(end, &fim.end);
-                rope.insert(cursor_index, &fim.middle);
-                rope.insert(start, &fim.start);
-                let rope_slice = rope
-                    .get_slice(
-                        start
-                            ..end
-                                + fim.start.chars().count()
-                                + fim.middle.chars().count()
-                                + fim.end.chars().count(),
-                    )
-                    .context("Error getting rope slice")?;
-                rope_slice.to_string()
+                if is_chat_enabled {
+                    rope.insert(cursor_index, "{CURSOR}");
+                    let rope_slice = rope
+                        .get_slice(start..end + "{CURSOR}".chars().count())
+                        .context("Error getting rope slice")?;
+                    rope_slice.to_string()
+                } else {
+                    let fim = r.1.unwrap(); // We can unwrap as we know it is some from the match
+                    rope.insert(end, &fim.end);
+                    rope.insert(cursor_index, &fim.middle);
+                    rope.insert(start, &fim.start);
+                    let rope_slice = rope
+                        .get_slice(
+                            start
+                                ..end
+                                    + fim.start.chars().count()
+                                    + fim.middle.chars().count()
+                                    + fim.end.chars().count(),
+                        )
+                        .context("Error getting rope slice")?;
+                    rope_slice.to_string()
+                }
             }
             _ => {
                 let start = cursor_index
                     .checked_sub(characters_to_estimated_tokens(
-                        self.configuration.get_maximum_context_length(),
+                        self.configuration.get_maximum_context_length()?,
                     ))
                     .unwrap_or(0);
                 let rope_slice = rope
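Aside (not part of the commit): the non-chat branch above builds the usual fill-in-the-middle layout by inserting the configured markers around the cursor. A plain-String sketch of that layout (the real code operates on a ropey::Rope and window-limits the text); the marker values are the ones used in the test configuration:

fn fim_prompt(text: &str, cursor: usize, start: &str, middle: &str, end: &str) -> String {
    // start + prefix + middle + suffix + end, mirroring the three rope.insert calls
    let (prefix, suffix) = text.split_at(cursor);
    format!("{start}{prefix}{middle}{suffix}{end}")
}

fn main() {
    let text = "def fib(n):\n    return fib(n - 1) + fib(n - 2)\n";
    let prompt = fim_prompt(text, 16, "<fim_prefix>", "<fim_suffix>", "<fim_middle>");
    println!("{prompt}");
}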
@@ -82,6 +116,7 @@ impl MemoryBackend for FileStore {
         Ok(Prompt::new("".to_string(), code))
     }
 
+    #[instrument(skip(self))]
     fn opened_text_document(
         &mut self,
         params: lsp_types::DidOpenTextDocumentParams,
@@ -92,6 +127,7 @@ impl MemoryBackend for FileStore {
         Ok(())
     }
 
+    #[instrument(skip(self))]
     fn changed_text_document(
         &mut self,
         params: lsp_types::DidChangeTextDocumentParams,
@@ -116,6 +152,7 @@ impl MemoryBackend for FileStore {
         Ok(())
     }
 
+    #[instrument(skip(self))]
     fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
         for file_rename in params.files {
             if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {
@@ -14,11 +14,17 @@ pub struct Prompt {
 }
 
 impl Prompt {
-    fn new(context: String, code: String) -> Self {
+    pub fn new(context: String, code: String) -> Self {
         Self { context, code }
     }
 }
 
+#[derive(Debug)]
+pub enum PromptForType {
+    Completion,
+    Generate,
+}
+
 pub trait MemoryBackend {
     fn init(&self) -> anyhow::Result<()> {
         Ok(())
@@ -26,7 +32,11 @@ pub trait MemoryBackend {
     fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
     fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
     fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
-    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt>;
+    fn build_prompt(
+        &self,
+        position: &TextDocumentPositionParams,
+        prompt_for_type: PromptForType,
+    ) -> anyhow::Result<Prompt>;
     fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
 }
 
@@ -1,39 +0,0 @@
-use crate::{
-    configuration::{Chat, ChatMessage, Configuration},
-    tokenizer::Tokenizer,
-};
-use hf_hub::api::sync::{Api, ApiRepo};
-
-// // Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
-// const CHATML_CHAT_TEMPLATE: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
-// const CHATML_BOS_TOKEN: &str = "<s>";
-// const CHATML_EOS_TOKEN: &str = "<|im_end|>";
-
-// // Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
-// const MISTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
-// const MISTRAL_INSTRUCT_BOS_TOKEN: &str = "<s>";
-// const MISTRAL_INSTRUCT_EOS_TOKEN: &str = "</s>";
-
-// // Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
-// const MIXTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
-
-pub struct Template {
-    configuration: Configuration,
-}
-
-// impl Template {
-//     pub fn new(configuration: Configuration) -> Self {
-//         Self { configuration }
-//     }
-// }
-
-pub fn apply_prompt(
-    chat_messages: Vec<ChatMessage>,
-    chat: &Chat,
-    tokenizer: Option<&Tokenizer>,
-) -> anyhow::Result<String> {
-    // If we have the chat template apply it
-    // If we have the chat_format see if we have it set
-    // If we don't have the chat_format set here, try and get the chat_template from the tokenizer_config.json file
-    anyhow::bail!("Please set chat_template or chat_format. Could not find the information in the tokenizer_config.json file")
-}
@@ -1,7 +0,0 @@
-pub struct Tokenizer {}
-
-impl Tokenizer {
-    pub fn maybe_from_repo(repo: ApiRepo) -> anyhow::Result<Option<Self>> {
-        unimplemented!()
-    }
-}
@@ -1,12 +1,11 @@
 use anyhow::Context;
 use hf_hub::api::sync::Api;
+use tracing::{debug, instrument};
 
 use super::TransformerBackend;
 use crate::{
-    configuration::{Chat, Configuration},
+    configuration::Configuration,
     memory_backends::Prompt,
-    template::{apply_prompt, Template},
-    tokenizer::Tokenizer,
     utils::format_chat_messages,
     worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
@@ -19,10 +18,10 @@ use model::Model;
 pub struct LlamaCPP {
     model: Model,
     configuration: Configuration,
-    tokenizer: Option<Tokenizer>,
 }
 
 impl LlamaCPP {
+    #[instrument]
     pub fn new(configuration: Configuration) -> anyhow::Result<Self> {
         let api = Api::new()?;
         let model = configuration.get_model()?;
@@ -32,44 +31,53 @@ impl LlamaCPP {
             .context("Model `name` is required when using GGUF models")?;
         let repo = api.model(model.repository.to_owned());
         let model_path = repo.get(&name)?;
-        let tokenizer: Option<Tokenizer> = Tokenizer::maybe_from_repo(repo)?;
         let model = Model::new(model_path, configuration.get_model_kwargs()?)?;
         Ok(Self {
             model,
             configuration,
-            tokenizer,
         })
     }
-}
 
-impl TransformerBackend for LlamaCPP {
-    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+    #[instrument(skip(self))]
+    fn get_prompt_string(&self, prompt: &Prompt) -> anyhow::Result<String> {
         // We need to check that they not only set the `chat` key, but they set the `completion` sub key
-        let prompt = match self.configuration.get_chat() {
+        Ok(match self.configuration.get_chat()? {
             Some(c) => {
                 if let Some(completion_messages) = &c.completion {
                     let chat_messages = format_chat_messages(completion_messages, prompt);
-                    apply_prompt(chat_messages, c, self.tokenizer.as_ref())?
+                    self.model
+                        .apply_chat_template(chat_messages, c.chat_template.to_owned())?
                 } else {
                     prompt.code.to_owned()
                 }
             }
             None => prompt.code.to_owned(),
-        };
-        let max_new_tokens = self.configuration.get_max_new_tokens().completion;
+        })
+    }
+}
+
+impl TransformerBackend for LlamaCPP {
+    #[instrument(skip(self))]
+    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+        let prompt = self.get_prompt_string(prompt)?;
+        // debug!("Prompt string for LLM: {}", prompt);
+        let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
         self.model
             .complete(&prompt, max_new_tokens)
             .map(|insert_text| DoCompletionResponse { insert_text })
     }
 
+    #[instrument(skip(self))]
     fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
-        unimplemented!()
-        // let max_new_tokens = self.configuration.get_max_new_tokens().generation;
-        // self.model
-        //     .complete(prompt, max_new_tokens)
-        //     .map(|generated_text| DoGenerateResponse { generated_text })
+        let prompt = self.get_prompt_string(prompt)?;
+        // debug!("Prompt string for LLM: {}", prompt);
+        let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
+        self.model
+            .complete(&prompt, max_new_tokens)
+            .map(|generated_text| DoGenerateResponse { generated_text })
     }
 
+    #[instrument(skip(self))]
     fn do_generate_stream(
         &self,
         _request: &GenerateStreamRequest,
@@ -132,7 +140,7 @@ mod tests {
             }
         });
         let configuration = Configuration::new(args).unwrap();
-        let model = LlamaCPP::new(configuration).unwrap();
+        let _model = LlamaCPP::new(configuration).unwrap();
         // let output = model.do_completion("def fibon").unwrap();
         // println!("{}", output.insert_text);
     }
@@ -4,13 +4,14 @@ use llama_cpp_2::{
     ggml_time_us,
     llama_backend::LlamaBackend,
     llama_batch::LlamaBatch,
-    model::{params::LlamaModelParams, AddBos, LlamaModel},
+    model::{params::LlamaModelParams, AddBos, LlamaChatMessage, LlamaModel},
     token::data_array::LlamaTokenDataArray,
 };
 use once_cell::sync::Lazy;
 use std::{num::NonZeroU32, path::PathBuf, time::Duration};
+use tracing::{debug, info, instrument};
 
-use crate::configuration::Kwargs;
+use crate::configuration::{ChatMessage, Kwargs};
 
 static BACKEND: Lazy<LlamaBackend> = Lazy::new(|| LlamaBackend::init().unwrap());
 
@@ -20,6 +21,7 @@ pub struct Model {
 }
 
 impl Model {
+    #[instrument]
     pub fn new(model_path: PathBuf, kwargs: &Kwargs) -> anyhow::Result<Self> {
         // Get n_gpu_layers if set in kwargs
         // As a default we set it to 1000, which should put all layers on the GPU
@@ -41,9 +43,8 @@ impl Model {
         };
 
         // Load the model
-        eprintln!("SETTING MODEL AT PATH: {:?}", model_path);
+        debug!("Loading model at path: {:?}", model_path);
         let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
-        eprintln!("\nMODEL SET\n");
 
         // Get n_ctx if set in kwargs
         // As a default we set it to 2048
@@ -60,6 +61,7 @@ impl Model {
         Ok(Model { model, n_ctx })
     }
 
+    #[instrument(skip(self))]
     pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
         // initialize the context
         let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone()));
@@ -77,9 +79,7 @@ impl Model {
         let n_cxt = ctx.n_ctx() as usize;
         let n_kv_req = tokens_list.len() + max_new_tokens;
 
-        eprintln!(
-            "n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}"
-        );
+        info!("n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}");
 
         // make sure the KV cache is big enough to hold all the prompt and generated tokens
         if n_kv_req > n_cxt {
@@ -132,14 +132,29 @@ impl Model {
 
         let t_main_end = ggml_time_us();
         let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
-        eprintln!(
+        info!(
             "decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
             n_decode,
             duration.as_secs_f32(),
             n_decode as f32 / duration.as_secs_f32()
         );
-        eprintln!("{}", ctx.timings());
+        info!("{}", ctx.timings());
 
         Ok(output.join(""))
     }
+
+    #[instrument(skip(self))]
+    pub fn apply_chat_template(
+        &self,
+        messages: Vec<ChatMessage>,
+        template: Option<String>,
+    ) -> anyhow::Result<String> {
+        let llama_chat_messages = messages
+            .into_iter()
+            .map(|c| LlamaChatMessage::new(c.role, c.content))
+            .collect::<Result<Vec<LlamaChatMessage>, _>>()?;
+        Ok(self
+            .model
+            .apply_chat_template(template, llama_chat_messages, true)?)
+    }
 }
@@ -9,7 +9,7 @@ pub trait ToResponseError {
 impl ToResponseError for anyhow::Error {
     fn to_response_error(&self, code: i32) -> ResponseError {
         ResponseError {
-            code: -32603,
+            code,
             message: self.to_string(),
             data: None,
         }
@@ -25,8 +25,8 @@ pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec
         .iter()
         .map(|m| ChatMessage {
             role: m.role.to_owned(),
-            message: m
-                .message
+            content: m
+                .content
                 .replace("{context}", &prompt.context)
                 .replace("{code}", &prompt.code),
         })
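Aside (not part of the commit): format_chat_messages fills the {context} and {code} placeholders from the prompt before the messages are handed to the chat template. A tiny sketch of that substitution on its own:

fn fill_placeholders(template: &str, context: &str, code: &str) -> String {
    template
        .replace("{context}", context)
        .replace("{code}", code)
}

fn main() {
    let filled = fill_placeholders("Complete the following code: \n\n{code}", "", "def fibn");
    assert_eq!(filled, "Complete the following code: \n\ndef fibn");
    println!("{filled}");
}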
@@ -5,14 +5,15 @@ use lsp_types::{
 };
 use parking_lot::Mutex;
 use std::{sync::Arc, thread};
+use tracing::instrument;
 
 use crate::custom_requests::generate::{GenerateParams, GenerateResult};
 use crate::custom_requests::generate_stream::GenerateStreamParams;
-use crate::memory_backends::MemoryBackend;
+use crate::memory_backends::{MemoryBackend, PromptForType};
 use crate::transformer_backends::TransformerBackend;
 use crate::utils::ToResponseError;
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct CompletionRequest {
     id: RequestId,
     params: CompletionParams,
@@ -24,7 +25,7 @@ impl CompletionRequest {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct GenerateRequest {
     id: RequestId,
     params: GenerateParams,
@@ -38,7 +39,7 @@ impl GenerateRequest {
 
 // The generate stream is not yet ready but we don't want to remove it
 #[allow(dead_code)]
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct GenerateStreamRequest {
     id: RequestId,
     params: GenerateStreamParams,
@@ -91,21 +92,17 @@ impl Worker {
         }
     }
 
+    #[instrument(skip(self))]
     fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
-        let prompt = self
-            .memory_backend
-            .lock()
-            .build_prompt(&request.params.text_document_position)?;
+        let prompt = self.memory_backend.lock().build_prompt(
+            &request.params.text_document_position,
+            PromptForType::Completion,
+        )?;
         let filter_text = self
             .memory_backend
             .lock()
             .get_filter_text(&request.params.text_document_position)?;
-        eprintln!("\nPROMPT**************\n{:?}\n******************\n", prompt);
         let response = self.transformer_backend.do_completion(&prompt)?;
-        eprintln!(
-            "\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{:?}\n&&&&&&&&&&&&&&&&&&\n",
-            response.insert_text
-        );
         let completion_text_edit = TextEdit::new(
             Range::new(
                 Position::new(
@@ -139,12 +136,12 @@ impl Worker {
         })
     }
 
+    #[instrument(skip(self))]
     fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
-        let prompt = self
-            .memory_backend
-            .lock()
-            .build_prompt(&request.params.text_document_position)?;
-        eprintln!("\nPROMPT*************\n{:?}\n************\n", prompt);
+        let prompt = self.memory_backend.lock().build_prompt(
+            &request.params.text_document_position,
+            PromptForType::Generate,
+        )?;
         let response = self.transformer_backend.do_generate(&prompt)?;
         let result = GenerateResult {
             generated_text: response.generated_text,