Mirror of https://github.com/SilasMarvin/lsp-ai.git (synced 2025-12-19 07:24:24 +01:00)
Cleaned so much stuff up add tracing add chat formatting
Cargo.lock (generated): 115 changed lines
@@ -138,9 +138,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"

[[package]]
name = "cc"
version = "1.0.86"
version = "1.0.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730"
checksum = "02f341c093d19155a6e41631ce5971aac4e9a868262212153124c15fa22d1cdc"
dependencies = [
"libc",
]
@@ -622,9 +622,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"

[[package]]
name = "llama-cpp-2"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8a5342c3eb45011e7e3646e22c5b8fcd3f25e049f0eb9618048e40b0027a59c"
version = "0.1.34"
dependencies = [
"llama-cpp-sys-2",
"thiserror",
@@ -633,9 +631,7 @@ dependencies = [

[[package]]
name = "llama-cpp-sys-2"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1813a55afed6298991bcaaee040b49a83b473b3571ce37b4bbaa4b294ebcc36"
version = "0.1.34"
dependencies = [
"bindgen",
"cc",
@@ -675,6 +671,8 @@ dependencies = [
"serde",
"serde_json",
"tokenizers",
"tracing",
"tracing-subscriber",
]

[[package]]
@@ -691,9 +689,9 @@ dependencies = [

[[package]]
name = "lsp-types"
version = "0.94.1"
version = "0.95.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1"
checksum = "158c1911354ef73e8fe42da6b10c0484cb65c7f1007f28022e847706c1ab6984"
dependencies = [
"bitflags 1.3.2",
"serde",
@@ -718,6 +716,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568"

[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]

[[package]]
name = "memchr"
version = "2.7.1"
@@ -797,6 +804,16 @@ dependencies = [
"minimal-lexical",
]

[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]

[[package]]
name = "number_prefix"
version = "0.4.0"
@@ -881,6 +898,12 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"

[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"

[[package]]
name = "parking_lot"
version = "0.12.1"
@@ -1057,10 +1080,19 @@ checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-automata 0.4.5",
"regex-syntax 0.8.2",
]

[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]

[[package]]
name = "regex-automata"
version = "0.4.5"
@@ -1072,6 +1104,12 @@ dependencies = [
"regex-syntax 0.8.2",
]

[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"

[[package]]
name = "regex-syntax"
version = "0.7.5"
@@ -1245,6 +1283,15 @@ dependencies = [
"syn 2.0.50",
]

[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]

[[package]]
name = "shlex"
version = "1.3.0"
@@ -1364,6 +1411,16 @@ dependencies = [
"syn 2.0.50",
]

[[package]]
name = "thread_local"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [
"cfg-if",
"once_cell",
]

[[package]]
name = "tinyvec"
version = "1.6.0"
@@ -1441,6 +1498,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
]

[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]

[[package]]
name = "tracing-subscriber"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]

[[package]]
@@ -1536,6 +1623,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"

[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"

[[package]]
name = "vcpkg"
version = "0.2.15"
Cargo.toml: 10 changed lines
@@ -7,9 +7,9 @@ edition = "2021"

[dependencies]
anyhow = "1.0.75"
lsp-server = "0.7.4"
lsp-server = "0.7.6"
# lsp-server = { path = "../rust-analyzer/lib/lsp-server" }
lsp-types = "0.94.1"
lsp-types = "0.95.0"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
@@ -19,9 +19,11 @@ tokenizers = "0.14.1"
parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
llama-cpp-2 = "0.1.31"
# llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
# llama-cpp-2 = "0.1.31"
llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
minijinja = "1.0.12"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"

[features]
default = []
@@ -3,8 +3,6 @@ use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;

use crate::memory_backends::Prompt;

#[cfg(target_os = "macos")]
const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;

@@ -26,7 +24,7 @@ pub enum ValidTransformerBackend {
#[derive(Debug, Clone, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub message: String,
pub content: String,
}

#[derive(Debug, Clone, Deserialize)]
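A quick illustration of the ChatMessage change above: the field users set on each chat message is now `content` rather than `message`, matching the role/content shape that `LlamaChatMessage` expects later in the llama.cpp backend. A minimal, standalone sketch (local copy of the struct, for illustration only) of what now deserializes:

```rust
use serde::Deserialize;
use serde_json::json;

// Local mirror of the struct above, for illustration only.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

fn main() -> anyhow::Result<()> {
    // `content` is the accepted key after this commit; a config still using
    // `message` would fail here because `content` is missing.
    let msg: ChatMessage = serde_json::from_value(json!({
        "role": "system",
        "content": "You are a code completion chatbot.\n\n{context}"
    }))?;
    println!("{:?}", msg);
    Ok(())
}
```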
@@ -37,14 +35,14 @@ pub struct Chat {
pub chat_format: Option<String>,
}

#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct FIM {
pub start: String,
pub middle: String,
pub end: String,
}

#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct MaxNewTokens {
pub completion: usize,
pub generation: usize,
@@ -59,7 +57,7 @@ impl Default for MaxNewTokens {
}
}

#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
struct ValidMemoryConfiguration {
file_store: Option<Value>,
}
@@ -72,13 +70,13 @@ impl Default for ValidMemoryConfiguration {
}
}

#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct Model {
pub repository: String,
pub name: Option<String>,
}

#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
struct ModelGGUF {
// The model to use
#[serde(flatten)]
@@ -114,7 +112,7 @@ impl Default for ModelGGUF {
}
}

#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
struct ValidMacTransformerConfiguration {
model_gguf: Option<ModelGGUF>,
}
@@ -127,7 +125,7 @@ impl Default for ValidMacTransformerConfiguration {
}
}

#[derive(Clone, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
struct ValidLinuxTransformerConfiguration {
model_gguf: Option<ModelGGUF>,
}
@@ -140,7 +138,7 @@ impl Default for ValidLinuxTransformerConfiguration {
}
}

#[derive(Clone, Deserialize, Default)]
#[derive(Clone, Debug, Deserialize, Default)]
struct ValidConfiguration {
memory: ValidMemoryConfiguration,
#[cfg(target_os = "macos")]
@@ -151,7 +149,7 @@ struct ValidConfiguration {
transformer: ValidLinuxTransformerConfiguration,
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Configuration {
valid_config: ValidConfiguration,
}
@@ -175,7 +173,7 @@ impl Configuration {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
Ok(&model_gguf.model)
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}

@@ -183,7 +181,7 @@ impl Configuration {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
Ok(&model_gguf.kwargs)
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}

@@ -203,9 +201,9 @@ impl Configuration {
}
}

pub fn get_maximum_context_length(&self) -> usize {
pub fn get_maximum_context_length(&self) -> Result<usize> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf
Ok(model_gguf
.kwargs
.get("n_ctx")
.map(|v| {
@@ -213,91 +211,33 @@ impl Configuration {
.map(|u| u as usize)
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
})
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX))
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}

pub fn get_max_new_tokens(&self) -> &MaxNewTokens {
pub fn get_max_new_tokens(&self) -> Result<&MaxNewTokens> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
&model_gguf.max_new_tokens
Ok(&model_gguf.max_new_tokens)
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}

pub fn get_fim(&self) -> Option<&FIM> {
pub fn get_fim(&self) -> Result<Option<&FIM>> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf.fim.as_ref()
Ok(model_gguf.fim.as_ref())
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}

pub fn get_chat(&self) -> Option<&Chat> {
pub fn get_chat(&self) -> Result<Option<&Chat>> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf.chat.as_ref()
Ok(model_gguf.chat.as_ref())
} else {
panic!("We currently only support gguf models using llama cpp")
anyhow::bail!("We currently only support gguf models using llama cpp")
}
}
}
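The accessor changes above replace panics with anyhow::bail!, so every Configuration getter now returns a Result and a misconfiguration can surface as an error response instead of crashing the server. A hedged sketch of what call sites look like after the change (mirroring the `?` usage seen in the file store and llama.cpp backend diffs below; the types here are hypothetical stand-ins, not the crate's):

```rust
use anyhow::Result;

// Hypothetical stand-in for the real Configuration; only the signatures matter here.
struct Configuration;

impl Configuration {
    fn get_maximum_context_length(&self) -> Result<usize> {
        Ok(1024)
    }
    fn get_max_new_tokens_completion(&self) -> Result<usize> {
        Ok(32)
    }
}

fn build_something(config: &Configuration) -> Result<String> {
    // Errors now propagate with `?` instead of panicking inside the getter.
    let n_ctx = config.get_maximum_context_length()?;
    let max_new_tokens = config.get_max_new_tokens_completion()?;
    Ok(format!("n_ctx = {n_ctx}, max_new_tokens = {max_new_tokens}"))
}

fn main() -> Result<()> {
    println!("{}", build_something(&Configuration)?);
    Ok(())
}
```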
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn custom_mac_gguf_model() {
let args = json!({
"initializationOptions": {
"memory": {
"file_store": {}
},
"macos": {
"model_gguf": {
// "repository": "deepseek-coder-6.7b-base",
// "name": "Q4_K_M.gguf",
"repository": "stabilityai/stablelm-2-zephyr-1_6b",
"name": "stablelm-2-zephyr-1_6b-Q5_K_M.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
"fim": {
"start": "<fim_prefix>",
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"chat": {
"completion": [
{
"role": "system",
"message": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief.\n\n{context}",
},
{
"role": "user",
"message": "Complete the following code: \n\n{code}"
}
],
"generation": [
{
"role": "system",
"message": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
},
{
"role": "user",
"message": "Complete the following code: \n\n{code}"
}
]
},
"n_ctx": 2048,
"n_gpu_layers": 35,
}
},
}
});
let _ = Configuration::new(args).unwrap();
}
}
src/main.rs: 86 changed lines
@@ -7,12 +7,12 @@ use lsp_types::{
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use tracing::error;
use tracing_subscriber::{EnvFilter, FmtSubscriber};

mod configuration;
mod custom_requests;
mod memory_backends;
mod template;
mod tokenizer;
mod transformer_backends;
mod utils;
mod worker;
@@ -43,6 +43,13 @@ where
}

fn main() -> Result<()> {
// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
// If the variables value is malformed or missing, sets the default log level to ERROR
FmtSubscriber::builder()
.with_writer(std::io::stderr)
.with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
.init();

let (connection, io_threads) = Connection::stdio();
let server_capabilities = serde_json::to_value(&ServerCapabilities {
completion_provider: Some(CompletionOptions::default()),
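The hunk above is where the new tracing setup lands: logs go to stderr (stdout carries the LSP JSON-RPC stream via Connection::stdio) and are filtered by the LSP_AI_LOG environment variable. A small standalone sketch of the same pattern, useful for seeing what different LSP_AI_LOG values do (the example filter values are illustrative, and the ERROR fallback is the behavior the comment above describes):

```rust
use tracing_subscriber::{EnvFilter, FmtSubscriber};

fn main() {
    // Same subscriber shape as in main(): write to stderr and read the filter
    // from LSP_AI_LOG, e.g. `LSP_AI_LOG=debug` or `LSP_AI_LOG=lsp_ai=debug`.
    FmtSubscriber::builder()
        .with_writer(std::io::stderr)
        .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
        .init();

    tracing::error!("logged under the default ERROR-level fallback described above");
    tracing::debug!("only logged when LSP_AI_LOG enables the debug level");
}
```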
@@ -104,12 +111,11 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
if request_is::<Completion>(&req) {
match cast::<Completion>(req) {
Ok((id, params)) => {
eprintln!("******{:?}********", id);
let mut lcr = last_worker_request.lock();
let completion_request = CompletionRequest::new(id, params);
*lcr = Some(WorkerRequest::Completion(completion_request));
}
Err(err) => eprintln!("{err:?}"),
Err(err) => error!("{err:?}"),
}
} else if request_is::<Generate>(&req) {
match cast::<Generate>(req) {
@@ -118,7 +124,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let completion_request = GenerateRequest::new(id, params);
*lcr = Some(WorkerRequest::Generate(completion_request));
}
Err(err) => eprintln!("{err:?}"),
Err(err) => error!("{err:?}"),
}
} else if request_is::<GenerateStream>(&req) {
match cast::<GenerateStream>(req) {
@@ -127,10 +133,10 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let completion_request = GenerateStreamRequest::new(id, params);
*lcr = Some(WorkerRequest::GenerateStream(completion_request));
}
Err(err) => eprintln!("{err:?}"),
Err(err) => error!("{err:?}"),
}
} else {
eprintln!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream")
error!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream")
}
}
Message::Notification(not) => {
@@ -150,3 +156,69 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
}
Ok(())
}

#[cfg(test)]
mod tests {
use crate::memory_backends::Prompt;

use super::*;
use serde_json::json;

#[test]
fn custom_mac_gguf_model() {
let args = json!({
"initializationOptions": {
"memory": {
"file_store": {}
},
"macos": {
"model_gguf": {
"repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
"name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
// "repository": "stabilityai/stablelm-2-zephyr-1_6b",
// "name": "stablelm-2-zephyr-1_6b-Q5_K_M.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
// "fim": {
// "start": "<fim_prefix>",
// "middle": "<fim_suffix>",
// "end": "<fim_middle>"
// },
// "chat": {
// "completion": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
// "generation": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
// // "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
// },
"n_ctx": 2048,
"n_gpu_layers": 35,
}
},
}
});
let configuration = Configuration::new(args).unwrap();
let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
let prompt = Prompt::new("".to_string(), "def fibn".to_string());
let response = backend.do_completion(&prompt).unwrap();
eprintln!("\nRESPONSE:\n{:?}", response.insert_text);
}
}
@@ -2,10 +2,11 @@ use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use ropey::Rope;
use std::collections::HashMap;
use tracing::instrument;

use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};

use super::{MemoryBackend, Prompt};
use super::{MemoryBackend, Prompt, PromptForType};

pub struct FileStore {
configuration: Configuration,
@@ -22,6 +23,7 @@ impl FileStore {
}

impl MemoryBackend for FileStore {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
let rope = self
.file_map
@@ -34,7 +36,12 @@ impl MemoryBackend for FileStore {
.to_string())
}

fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt> {
#[instrument(skip(self))]
fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
let mut rope = self
.file_map
.get(position.text_document.uri.as_str())
@@ -44,15 +51,41 @@ impl MemoryBackend for FileStore {
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;

// We only want to do FIM if the user has enabled it, and the cursor is not at the end of the file
let code = match self.configuration.get_fim() {
Some(fim) if rope.len_chars() != cursor_index => {
let max_length =
characters_to_estimated_tokens(self.configuration.get_maximum_context_length());
let is_chat_enabled = match prompt_for_type {
PromptForType::Completion => self
.configuration
.get_chat()?
.map(|c| c.completion.is_some())
.unwrap_or(false),
PromptForType::Generate => self
.configuration
.get_chat()?
.map(|c| c.generation.is_some())
.unwrap_or(false),
};

// We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
// and the user has not enabled chat
let code = match (is_chat_enabled, self.configuration.get_fim()?) {
r @ (true, _) | r @ (false, Some(_))
if is_chat_enabled || rope.len_chars() != cursor_index =>
{
let max_length = characters_to_estimated_tokens(
self.configuration.get_maximum_context_length()?,
);
let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));

if is_chat_enabled {
rope.insert(cursor_index, "{CURSOR}");
let rope_slice = rope
.get_slice(start..end + "{CURSOR}".chars().count())
.context("Error getting rope slice")?;
rope_slice.to_string()
} else {
let fim = r.1.unwrap(); // We can unwrap as we know it is some from the match
rope.insert(end, &fim.end);
rope.insert(cursor_index, &fim.middle);
rope.insert(start, &fim.start);
@@ -67,10 +100,11 @@ impl MemoryBackend for FileStore {
.context("Error getting rope slice")?;
rope_slice.to_string()
}
}
_ => {
let start = cursor_index
.checked_sub(characters_to_estimated_tokens(
self.configuration.get_maximum_context_length(),
self.configuration.get_maximum_context_length()?,
))
.unwrap_or(0);
let rope_slice = rope
@@ -82,6 +116,7 @@ impl MemoryBackend for FileStore {
Ok(Prompt::new("".to_string(), code))
}

#[instrument(skip(self))]
fn opened_text_document(
&mut self,
params: lsp_types::DidOpenTextDocumentParams,
@@ -92,6 +127,7 @@ impl MemoryBackend for FileStore {
Ok(())
}

#[instrument(skip(self))]
fn changed_text_document(
&mut self,
params: lsp_types::DidChangeTextDocumentParams,
@@ -116,6 +152,7 @@ impl MemoryBackend for FileStore {
Ok(())
}

#[instrument(skip(self))]
fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
for file_rename in params.files {
if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {
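To make the FIM branch above concrete: when fill-in-the-middle is configured and chat is not, the prompt is the window around the cursor wrapped in the configured start/middle/end tokens, while the chat branch instead marks the cursor with {CURSOR} for the chat messages to reference. A standalone sketch of the resulting string, using the token values from the test config in this commit:

```rust
// Illustration only: a plain-string version of what the rope edits above
// produce for the FIM case. `start`/`middle`/`end` come from the `fim`
// config section, e.g. "<fim_prefix>", "<fim_suffix>", "<fim_middle>".
struct Fim {
    start: String,
    middle: String,
    end: String,
}

fn fim_prompt(fim: &Fim, before_cursor: &str, after_cursor: &str) -> String {
    // fim.start is inserted at the window start, fim.middle at the cursor,
    // and fim.end at the window end.
    format!(
        "{}{}{}{}{}",
        fim.start, before_cursor, fim.middle, after_cursor, fim.end
    )
}

fn main() {
    let fim = Fim {
        start: "<fim_prefix>".into(),
        middle: "<fim_suffix>".into(),
        end: "<fim_middle>".into(),
    };
    println!("{}", fim_prompt(&fim, "def fib(n):\n    ", "\nprint(fib(10))"));
    // -> <fim_prefix>def fib(n):\n    <fim_suffix>\nprint(fib(10))<fim_middle>
}
```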
@@ -14,11 +14,17 @@ pub struct Prompt {
}

impl Prompt {
fn new(context: String, code: String) -> Self {
pub fn new(context: String, code: String) -> Self {
Self { context, code }
}
}

#[derive(Debug)]
pub enum PromptForType {
Completion,
Generate,
}

pub trait MemoryBackend {
fn init(&self) -> anyhow::Result<()> {
Ok(())
@@ -26,7 +32,11 @@ pub trait MemoryBackend {
fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt>;
fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt>;
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
}
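The new PromptForType argument lets a memory backend know whether it is building context for a completion or a generate request (the worker passes PromptForType::Completion and PromptForType::Generate respectively, as seen further down). A small hypothetical sketch of how a backend could branch on it; the budget split is invented for illustration, not the crate's behavior:

```rust
// Local mirror of the enum above, so this sketch compiles on its own.
#[derive(Debug)]
pub enum PromptForType {
    Completion,
    Generate,
}

// Hypothetical helper: completions might get a tight window around the cursor,
// while generate requests might get a larger, generation-oriented budget.
fn context_budget(prompt_for_type: &PromptForType, max_context: usize) -> usize {
    match prompt_for_type {
        PromptForType::Completion => max_context / 2,
        PromptForType::Generate => max_context,
    }
}

fn main() {
    println!("{}", context_budget(&PromptForType::Completion, 2048));
    println!("{}", context_budget(&PromptForType::Generate, 2048));
}
```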
@@ -1,39 +0,0 @@
use crate::{
configuration::{Chat, ChatMessage, Configuration},
tokenizer::Tokenizer,
};
use hf_hub::api::sync::{Api, ApiRepo};

// // Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
// const CHATML_CHAT_TEMPLATE: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
// const CHATML_BOS_TOKEN: &str = "<s>";
// const CHATML_EOS_TOKEN: &str = "<|im_end|>";

// // Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
// const MISTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
// const MISTRAL_INSTRUCT_BOS_TOKEN: &str = "<s>";
// const MISTRAL_INSTRUCT_EOS_TOKEN: &str = "</s>";

// // Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
// const MIXTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";

pub struct Template {
configuration: Configuration,
}

// impl Template {
// pub fn new(configuration: Configuration) -> Self {
// Self { configuration }
// }
// }

pub fn apply_prompt(
chat_messages: Vec<ChatMessage>,
chat: &Chat,
tokenizer: Option<&Tokenizer>,
) -> anyhow::Result<String> {
// If we have the chat template apply it
// If we have the chat_format see if we have it set
// If we don't have the chat_format set here, try and get the chat_template from the tokenizer_config.json file
anyhow::bail!("Please set chat_template or chat_format. Could not find the information in the tokenizer_config.json file")
}
@@ -1,7 +0,0 @@
pub struct Tokenizer {}

impl Tokenizer {
pub fn maybe_from_repo(repo: ApiRepo) -> anyhow::Result<Option<Self>> {
unimplemented!()
}
}
@@ -1,12 +1,11 @@
use anyhow::Context;
use hf_hub::api::sync::Api;
use tracing::{debug, instrument};

use super::TransformerBackend;
use crate::{
configuration::{Chat, Configuration},
configuration::Configuration,
memory_backends::Prompt,
template::{apply_prompt, Template},
tokenizer::Tokenizer,
utils::format_chat_messages,
worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
@@ -19,10 +18,10 @@ use model::Model;
pub struct LlamaCPP {
model: Model,
configuration: Configuration,
tokenizer: Option<Tokenizer>,
}

impl LlamaCPP {
#[instrument]
pub fn new(configuration: Configuration) -> anyhow::Result<Self> {
let api = Api::new()?;
let model = configuration.get_model()?;
@@ -32,44 +31,53 @@ impl LlamaCPP {
.context("Model `name` is required when using GGUF models")?;
let repo = api.model(model.repository.to_owned());
let model_path = repo.get(&name)?;
let tokenizer: Option<Tokenizer> = Tokenizer::maybe_from_repo(repo)?;
let model = Model::new(model_path, configuration.get_model_kwargs()?)?;
Ok(Self {
model,
configuration,
tokenizer,
})
}
}

impl TransformerBackend for LlamaCPP {
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
#[instrument(skip(self))]
fn get_prompt_string(&self, prompt: &Prompt) -> anyhow::Result<String> {
// We need to check that they not only set the `chat` key, but they set the `completion` sub key
let prompt = match self.configuration.get_chat() {
Ok(match self.configuration.get_chat()? {
Some(c) => {
if let Some(completion_messages) = &c.completion {
let chat_messages = format_chat_messages(completion_messages, prompt);
apply_prompt(chat_messages, c, self.tokenizer.as_ref())?
self.model
.apply_chat_template(chat_messages, c.chat_template.to_owned())?
} else {
prompt.code.to_owned()
}
}
None => prompt.code.to_owned(),
};
let max_new_tokens = self.configuration.get_max_new_tokens().completion;
})
}
}

impl TransformerBackend for LlamaCPP {
#[instrument(skip(self))]
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
let prompt = self.get_prompt_string(prompt)?;
// debug!("Prompt string for LLM: {}", prompt);
let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
self.model
.complete(&prompt, max_new_tokens)
.map(|insert_text| DoCompletionResponse { insert_text })
}

#[instrument(skip(self))]
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
unimplemented!()
// let max_new_tokens = self.configuration.get_max_new_tokens().generation;
// self.model
// .complete(prompt, max_new_tokens)
// .map(|generated_text| DoGenerateResponse { generated_text })
let prompt = self.get_prompt_string(prompt)?;
// debug!("Prompt string for LLM: {}", prompt);
let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
self.model
.complete(&prompt, max_new_tokens)
.map(|generated_text| DoGenerateResponse { generated_text })
}

#[instrument(skip(self))]
fn do_generate_stream(
&self,
_request: &GenerateStreamRequest,
@@ -132,7 +140,7 @@ mod tests {
}
});
let configuration = Configuration::new(args).unwrap();
let model = LlamaCPP::new(configuration).unwrap();
let _model = LlamaCPP::new(configuration).unwrap();
// let output = model.do_completion("def fibon").unwrap();
// println!("{}", output.insert_text);
}
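The get_prompt_string refactor above decides between two paths: if the configuration has chat completion messages, their {context}/{code} placeholders are filled in (format_chat_messages) and the result is rendered through the model's chat template; otherwise the raw code window is the prompt. A simplified, self-contained sketch of that flow, using local stand-in types rather than the crate's:

```rust
#[derive(Clone)]
struct ChatMessage {
    role: String,
    content: String,
}

// Stand-in for format_chat_messages: substitute the prompt pieces into each message.
fn fill(messages: &[ChatMessage], context: &str, code: &str) -> Vec<ChatMessage> {
    messages
        .iter()
        .map(|m| ChatMessage {
            role: m.role.clone(),
            content: m.content.replace("{context}", context).replace("{code}", code),
        })
        .collect()
}

// Stand-in for the chat-vs-plain decision in get_prompt_string.
fn prompt_string(
    completion_messages: Option<&[ChatMessage]>,
    context: &str,
    code: &str,
    apply_chat_template: impl Fn(Vec<ChatMessage>) -> String,
) -> String {
    match completion_messages {
        Some(messages) => apply_chat_template(fill(messages, context, code)),
        None => code.to_string(),
    }
}

fn main() {
    let messages = [ChatMessage {
        role: "user".into(),
        content: "Complete the following code: \n\n{code}".into(),
    }];
    // A toy "template" that just tags each message with its role.
    let rendered = prompt_string(Some(&messages), "", "def fib(", |msgs| {
        msgs.iter()
            .map(|m| format!("<{}>{}</{}>", m.role, m.content, m.role))
            .collect::<Vec<_>>()
            .join("\n")
    });
    println!("{rendered}");
}
```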
@@ -4,13 +4,14 @@ use llama_cpp_2::{
ggml_time_us,
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{params::LlamaModelParams, AddBos, LlamaModel},
model::{params::LlamaModelParams, AddBos, LlamaChatMessage, LlamaModel},
token::data_array::LlamaTokenDataArray,
};
use once_cell::sync::Lazy;
use std::{num::NonZeroU32, path::PathBuf, time::Duration};
use tracing::{debug, info, instrument};

use crate::configuration::Kwargs;
use crate::configuration::{ChatMessage, Kwargs};

static BACKEND: Lazy<LlamaBackend> = Lazy::new(|| LlamaBackend::init().unwrap());

@@ -20,6 +21,7 @@ pub struct Model {
}

impl Model {
#[instrument]
pub fn new(model_path: PathBuf, kwargs: &Kwargs) -> anyhow::Result<Self> {
// Get n_gpu_layers if set in kwargs
// As a default we set it to 1000, which should put all layers on the GPU
@@ -41,9 +43,8 @@ impl Model {
};

// Load the model
eprintln!("SETTING MODEL AT PATH: {:?}", model_path);
debug!("Loading model at path: {:?}", model_path);
let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
eprintln!("\nMODEL SET\n");

// Get n_ctx if set in kwargs
// As a default we set it to 2048
@@ -60,6 +61,7 @@ impl Model {
Ok(Model { model, n_ctx })
}

#[instrument(skip(self))]
pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
// initialize the context
let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone()));
@@ -77,9 +79,7 @@ impl Model {
let n_cxt = ctx.n_ctx() as usize;
let n_kv_req = tokens_list.len() + max_new_tokens;

eprintln!(
"n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}"
);
info!("n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}");

// make sure the KV cache is big enough to hold all the prompt and generated tokens
if n_kv_req > n_cxt {
@@ -132,14 +132,29 @@ impl Model {

let t_main_end = ggml_time_us();
let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
eprintln!(
info!(
"decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
n_decode,
duration.as_secs_f32(),
n_decode as f32 / duration.as_secs_f32()
);
eprintln!("{}", ctx.timings());
info!("{}", ctx.timings());

Ok(output.join(""))
}

#[instrument(skip(self))]
pub fn apply_chat_template(
&self,
messages: Vec<ChatMessage>,
template: Option<String>,
) -> anyhow::Result<String> {
let llama_chat_messages = messages
.into_iter()
.map(|c| LlamaChatMessage::new(c.role, c.content))
.collect::<Result<Vec<LlamaChatMessage>, _>>()?;
Ok(self
.model
.apply_chat_template(template, llama_chat_messages, true)?)
}
}
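For context, chat formatting now leans on llama-cpp-2's own chat-template support rather than the removed template.rs/minijinja path. A hedged usage sketch of the underlying calls the new apply_chat_template wrapper makes; the model path is a placeholder, and passing None for the template presumably falls back to the template embedded in the GGUF metadata, which is what the wrapper above allows:

```rust
use std::path::PathBuf;

use llama_cpp_2::{
    llama_backend::LlamaBackend,
    model::{params::LlamaModelParams, LlamaChatMessage, LlamaModel},
};

fn main() -> anyhow::Result<()> {
    let backend = LlamaBackend::init()?;
    let model = LlamaModel::load_from_file(
        &backend,
        PathBuf::from("/path/to/model.gguf"), // placeholder path
        &LlamaModelParams::default(),
    )?;
    let messages = vec![
        LlamaChatMessage::new(
            "system".to_string(),
            "You are a code completion chatbot.".to_string(),
        )?,
        LlamaChatMessage::new(
            "user".to_string(),
            "Complete the following code: \n\ndef fib(".to_string(),
        )?,
    ];
    // Same call shape as the wrapper above; the final `true` mirrors its hard-coded argument.
    let prompt = model.apply_chat_template(None, messages, true)?;
    println!("{prompt}");
    Ok(())
}
```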
@@ -9,7 +9,7 @@ pub trait ToResponseError {
impl ToResponseError for anyhow::Error {
fn to_response_error(&self, code: i32) -> ResponseError {
ResponseError {
code: -32603,
code,
message: self.to_string(),
data: None,
}
@@ -25,8 +25,8 @@ pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec
.iter()
.map(|m| ChatMessage {
role: m.role.to_owned(),
message: m
.message
content: m
.content
.replace("{context}", &prompt.context)
.replace("{code}", &prompt.code),
})
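The first hunk above makes the JSON-RPC error code a caller choice instead of a hard-coded -32603 (InternalError). A self-contained sketch of the changed trait, with local stand-ins for the lsp_server types:

```rust
// Local stand-in for lsp_server::ResponseError, for illustration only.
#[derive(Debug)]
struct ResponseError {
    code: i32,
    message: String,
    data: Option<serde_json::Value>,
}

trait ToResponseError {
    fn to_response_error(&self, code: i32) -> ResponseError;
}

impl ToResponseError for anyhow::Error {
    fn to_response_error(&self, code: i32) -> ResponseError {
        ResponseError {
            code, // previously always -32603; now supplied by the caller
            message: self.to_string(),
            data: None,
        }
    }
}

fn main() {
    let err = anyhow::anyhow!("something went wrong");
    // Callers that still want JSON-RPC InternalError can pass -32603 explicitly.
    println!("{:?}", err.to_response_error(-32603));
}
```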
@@ -5,14 +5,15 @@ use lsp_types::{
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use tracing::instrument;

use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::GenerateStreamParams;
use crate::memory_backends::MemoryBackend;
use crate::memory_backends::{MemoryBackend, PromptForType};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct CompletionRequest {
id: RequestId,
params: CompletionParams,
@@ -24,7 +25,7 @@ impl CompletionRequest {
}
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct GenerateRequest {
id: RequestId,
params: GenerateParams,
@@ -38,7 +39,7 @@ impl GenerateRequest {

// The generate stream is not yet ready but we don't want to remove it
#[allow(dead_code)]
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct GenerateStreamRequest {
id: RequestId,
params: GenerateStreamParams,
@@ -91,21 +92,17 @@ impl Worker {
}
}

#[instrument(skip(self))]
fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
let prompt = self
.memory_backend
.lock()
.build_prompt(&request.params.text_document_position)?;
let prompt = self.memory_backend.lock().build_prompt(
&request.params.text_document_position,
PromptForType::Completion,
)?;
let filter_text = self
.memory_backend
.lock()
.get_filter_text(&request.params.text_document_position)?;
eprintln!("\nPROMPT**************\n{:?}\n******************\n", prompt);
let response = self.transformer_backend.do_completion(&prompt)?;
eprintln!(
"\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{:?}\n&&&&&&&&&&&&&&&&&&\n",
response.insert_text
);
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
@@ -139,12 +136,12 @@ impl Worker {
})
}

#[instrument(skip(self))]
fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
let prompt = self
.memory_backend
.lock()
.build_prompt(&request.params.text_document_position)?;
eprintln!("\nPROMPT*************\n{:?}\n************\n", prompt);
let prompt = self.memory_backend.lock().build_prompt(
&request.params.text_document_position,
PromptForType::Generate,
)?;
let response = self.transformer_backend.do_generate(&prompt)?;
let result = GenerateResult {
generated_text: response.generated_text,