Cleaned so much stuff up, added tracing and chat formatting

Author: Silas Marvin
Date:   2024-03-03 14:40:42 -08:00
parent 28b7b1b74e
commit 6627da705e
12 changed files with 353 additions and 225 deletions

Cargo.lock (generated)
View File

@@ -138,9 +138,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.86" version = "1.0.88"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730" checksum = "02f341c093d19155a6e41631ce5971aac4e9a868262212153124c15fa22d1cdc"
dependencies = [ dependencies = [
"libc", "libc",
] ]
@@ -622,9 +622,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]] [[package]]
name = "llama-cpp-2" name = "llama-cpp-2"
version = "0.1.31" version = "0.1.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8a5342c3eb45011e7e3646e22c5b8fcd3f25e049f0eb9618048e40b0027a59c"
dependencies = [ dependencies = [
"llama-cpp-sys-2", "llama-cpp-sys-2",
"thiserror", "thiserror",
@@ -633,9 +631,7 @@ dependencies = [
[[package]] [[package]]
name = "llama-cpp-sys-2" name = "llama-cpp-sys-2"
version = "0.1.31" version = "0.1.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1813a55afed6298991bcaaee040b49a83b473b3571ce37b4bbaa4b294ebcc36"
dependencies = [ dependencies = [
"bindgen", "bindgen",
"cc", "cc",
@@ -675,6 +671,8 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"tokenizers", "tokenizers",
"tracing",
"tracing-subscriber",
] ]
[[package]] [[package]]
@@ -691,9 +689,9 @@ dependencies = [
[[package]] [[package]]
name = "lsp-types" name = "lsp-types"
version = "0.94.1" version = "0.95.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1" checksum = "158c1911354ef73e8fe42da6b10c0484cb65c7f1007f28022e847706c1ab6984"
dependencies = [ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"serde", "serde",
@@ -718,6 +716,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568" checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568"
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.7.1" version = "2.7.1"
@@ -797,6 +804,16 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]] [[package]]
name = "number_prefix" name = "number_prefix"
version = "0.4.0" version = "0.4.0"
@@ -881,6 +898,12 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.12.1" version = "0.12.1"
@@ -1057,10 +1080,19 @@ checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
"regex-automata", "regex-automata 0.4.5",
"regex-syntax 0.8.2", "regex-syntax 0.8.2",
] ]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]
[[package]] [[package]]
name = "regex-automata" name = "regex-automata"
version = "0.4.5" version = "0.4.5"
@@ -1072,6 +1104,12 @@ dependencies = [
"regex-syntax 0.8.2", "regex-syntax 0.8.2",
] ]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.7.5" version = "0.7.5"
@@ -1245,6 +1283,15 @@ dependencies = [
"syn 2.0.50", "syn 2.0.50",
] ]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]] [[package]]
name = "shlex" name = "shlex"
version = "1.3.0" version = "1.3.0"
@@ -1364,6 +1411,16 @@ dependencies = [
"syn 2.0.50", "syn 2.0.50",
] ]
[[package]]
name = "thread_local"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.6.0" version = "1.6.0"
@@ -1441,6 +1498,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
] ]
[[package]] [[package]]
@@ -1536,6 +1623,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]] [[package]]
name = "vcpkg" name = "vcpkg"
version = "0.2.15" version = "0.2.15"

View File

@@ -7,9 +7,9 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1.0.75" anyhow = "1.0.75"
lsp-server = "0.7.4" lsp-server = "0.7.6"
# lsp-server = { path = "../rust-analyzer/lib/lsp-server" } # lsp-server = { path = "../rust-analyzer/lib/lsp-server" }
lsp-types = "0.94.1" lsp-types = "0.95.0"
ropey = "1.6.1" ropey = "1.6.1"
serde = "1.0.190" serde = "1.0.190"
serde_json = "1.0.108" serde_json = "1.0.108"
@@ -19,9 +19,11 @@ tokenizers = "0.14.1"
parking_lot = "0.12.1" parking_lot = "0.12.1"
once_cell = "1.19.0" once_cell = "1.19.0"
directories = "5.0.1" directories = "5.0.1"
llama-cpp-2 = "0.1.31" # llama-cpp-2 = "0.1.31"
# llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" } llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
minijinja = "1.0.12" minijinja = "1.0.12"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
[features] [features]
default = [] default = []

View File

@@ -3,8 +3,6 @@ use serde::Deserialize;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::collections::HashMap; use std::collections::HashMap;
use crate::memory_backends::Prompt;
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024; const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
@@ -26,7 +24,7 @@ pub enum ValidTransformerBackend {
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct ChatMessage { pub struct ChatMessage {
pub role: String, pub role: String,
pub message: String, pub content: String,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
@@ -37,14 +35,14 @@ pub struct Chat {
pub chat_format: Option<String>, pub chat_format: Option<String>,
} }
#[derive(Clone, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct FIM { pub struct FIM {
pub start: String, pub start: String,
pub middle: String, pub middle: String,
pub end: String, pub end: String,
} }
#[derive(Clone, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct MaxNewTokens { pub struct MaxNewTokens {
pub completion: usize, pub completion: usize,
pub generation: usize, pub generation: usize,
@@ -59,7 +57,7 @@ impl Default for MaxNewTokens {
} }
} }
#[derive(Clone, Deserialize)] #[derive(Clone, Debug, Deserialize)]
struct ValidMemoryConfiguration { struct ValidMemoryConfiguration {
file_store: Option<Value>, file_store: Option<Value>,
} }
@@ -72,13 +70,13 @@ impl Default for ValidMemoryConfiguration {
} }
} }
#[derive(Clone, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Model { pub struct Model {
pub repository: String, pub repository: String,
pub name: Option<String>, pub name: Option<String>,
} }
#[derive(Clone, Deserialize)] #[derive(Clone, Debug, Deserialize)]
struct ModelGGUF { struct ModelGGUF {
// The model to use // The model to use
#[serde(flatten)] #[serde(flatten)]
@@ -114,7 +112,7 @@ impl Default for ModelGGUF {
} }
} }
#[derive(Clone, Deserialize)] #[derive(Clone, Debug, Deserialize)]
struct ValidMacTransformerConfiguration { struct ValidMacTransformerConfiguration {
model_gguf: Option<ModelGGUF>, model_gguf: Option<ModelGGUF>,
} }
@@ -127,7 +125,7 @@ impl Default for ValidMacTransformerConfiguration {
} }
} }
#[derive(Clone, Deserialize)] #[derive(Clone, Debug, Deserialize)]
struct ValidLinuxTransformerConfiguration { struct ValidLinuxTransformerConfiguration {
model_gguf: Option<ModelGGUF>, model_gguf: Option<ModelGGUF>,
} }
@@ -140,7 +138,7 @@ impl Default for ValidLinuxTransformerConfiguration {
} }
} }
#[derive(Clone, Deserialize, Default)] #[derive(Clone, Debug, Deserialize, Default)]
struct ValidConfiguration { struct ValidConfiguration {
memory: ValidMemoryConfiguration, memory: ValidMemoryConfiguration,
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
@@ -151,7 +149,7 @@ struct ValidConfiguration {
transformer: ValidLinuxTransformerConfiguration, transformer: ValidLinuxTransformerConfiguration,
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct Configuration { pub struct Configuration {
valid_config: ValidConfiguration, valid_config: ValidConfiguration,
} }
@@ -175,7 +173,7 @@ impl Configuration {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf { if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
Ok(&model_gguf.model) Ok(&model_gguf.model)
} else { } else {
panic!("We currently only support gguf models using llama cpp") anyhow::bail!("We currently only support gguf models using llama cpp")
} }
} }
@@ -183,7 +181,7 @@ impl Configuration {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf { if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
Ok(&model_gguf.kwargs) Ok(&model_gguf.kwargs)
} else { } else {
panic!("We currently only support gguf models using llama cpp") anyhow::bail!("We currently only support gguf models using llama cpp")
} }
} }
@@ -203,9 +201,9 @@ impl Configuration {
} }
} }
pub fn get_maximum_context_length(&self) -> usize { pub fn get_maximum_context_length(&self) -> Result<usize> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf { if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf Ok(model_gguf
.kwargs .kwargs
.get("n_ctx") .get("n_ctx")
.map(|v| { .map(|v| {
@@ -213,91 +211,33 @@ impl Configuration {
.map(|u| u as usize) .map(|u| u as usize)
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX) .unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
}) })
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX) .unwrap_or(DEFAULT_LLAMA_CPP_N_CTX))
} else { } else {
panic!("We currently only support gguf models using llama cpp") anyhow::bail!("We currently only support gguf models using llama cpp")
} }
} }
pub fn get_max_new_tokens(&self) -> &MaxNewTokens { pub fn get_max_new_tokens(&self) -> Result<&MaxNewTokens> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf { if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
&model_gguf.max_new_tokens Ok(&model_gguf.max_new_tokens)
} else { } else {
panic!("We currently only support gguf models using llama cpp") anyhow::bail!("We currently only support gguf models using llama cpp")
} }
} }
pub fn get_fim(&self) -> Option<&FIM> { pub fn get_fim(&self) -> Result<Option<&FIM>> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf { if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf.fim.as_ref() Ok(model_gguf.fim.as_ref())
} else { } else {
panic!("We currently only support gguf models using llama cpp") anyhow::bail!("We currently only support gguf models using llama cpp")
} }
} }
pub fn get_chat(&self) -> Option<&Chat> { pub fn get_chat(&self) -> Result<Option<&Chat>> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf { if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf.chat.as_ref() Ok(model_gguf.chat.as_ref())
} else { } else {
panic!("We currently only support gguf models using llama cpp") anyhow::bail!("We currently only support gguf models using llama cpp")
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn custom_mac_gguf_model() {
let args = json!({
"initializationOptions": {
"memory": {
"file_store": {}
},
"macos": {
"model_gguf": {
// "repository": "deepseek-coder-6.7b-base",
// "name": "Q4_K_M.gguf",
"repository": "stabilityai/stablelm-2-zephyr-1_6b",
"name": "stablelm-2-zephyr-1_6b-Q5_K_M.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
"fim": {
"start": "<fim_prefix>",
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"chat": {
"completion": [
{
"role": "system",
"message": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief.\n\n{context}",
},
{
"role": "user",
"message": "Complete the following code: \n\n{code}"
}
],
"generation": [
{
"role": "system",
"message": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
},
{
"role": "user",
"message": "Complete the following code: \n\n{code}"
}
]
},
"n_ctx": 2048,
"n_gpu_layers": 35,
}
},
}
});
let _ = Configuration::new(args).unwrap();
}
}

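Since `ChatMessage`'s `message` field was renamed to `content`, chat messages supplied through `initializationOptions` now use a `content` key. A trimmed, hypothetical test sketch of the expected shape (repository and file names are placeholders copied from the tests in this commit, and the serde defaults are assumed to behave as in those tests):

```rust
#[cfg(test)]
mod chat_config_shape {
    use super::*;
    use serde_json::json;

    // Illustrative only: chat messages now use the renamed `content` key.
    #[test]
    fn chat_messages_use_content_key() {
        let args = json!({
            "initializationOptions": {
                "memory": { "file_store": {} },
                "macos": {
                    "model_gguf": {
                        "repository": "stabilityai/stablelm-2-zephyr-1_6b",
                        "name": "stablelm-2-zephyr-1_6b-Q5_K_M.gguf",
                        "max_new_tokens": { "completion": 32, "generation": 256 },
                        "chat": {
                            "completion": [
                                { "role": "system", "content": "Use the context to complete the code.\n\n{context}" },
                                { "role": "user", "content": "Complete the following code: \n\n{code}" }
                            ]
                        },
                        "n_ctx": 2048
                    }
                }
            }
        });
        let _ = Configuration::new(args).unwrap();
    }
}
```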
View File

@@ -7,12 +7,12 @@ use lsp_types::{
}; };
use parking_lot::Mutex; use parking_lot::Mutex;
use std::{sync::Arc, thread}; use std::{sync::Arc, thread};
use tracing::error;
use tracing_subscriber::{EnvFilter, FmtSubscriber};
mod configuration; mod configuration;
mod custom_requests; mod custom_requests;
mod memory_backends; mod memory_backends;
mod template;
mod tokenizer;
mod transformer_backends; mod transformer_backends;
mod utils; mod utils;
mod worker; mod worker;
@@ -43,6 +43,13 @@ where
} }
fn main() -> Result<()> { fn main() -> Result<()> {
// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
// If the variable's value is malformed or missing, sets the default log level to ERROR
FmtSubscriber::builder()
.with_writer(std::io::stderr)
.with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
.init();
let (connection, io_threads) = Connection::stdio(); let (connection, io_threads) = Connection::stdio();
let server_capabilities = serde_json::to_value(&ServerCapabilities { let server_capabilities = serde_json::to_value(&ServerCapabilities {
completion_provider: Some(CompletionOptions::default()), completion_provider: Some(CompletionOptions::default()),
@@ -104,12 +111,11 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
if request_is::<Completion>(&req) { if request_is::<Completion>(&req) {
match cast::<Completion>(req) { match cast::<Completion>(req) {
Ok((id, params)) => { Ok((id, params)) => {
eprintln!("******{:?}********", id);
let mut lcr = last_worker_request.lock(); let mut lcr = last_worker_request.lock();
let completion_request = CompletionRequest::new(id, params); let completion_request = CompletionRequest::new(id, params);
*lcr = Some(WorkerRequest::Completion(completion_request)); *lcr = Some(WorkerRequest::Completion(completion_request));
} }
Err(err) => eprintln!("{err:?}"), Err(err) => error!("{err:?}"),
} }
} else if request_is::<Generate>(&req) { } else if request_is::<Generate>(&req) {
match cast::<Generate>(req) { match cast::<Generate>(req) {
@@ -118,7 +124,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let completion_request = GenerateRequest::new(id, params); let completion_request = GenerateRequest::new(id, params);
*lcr = Some(WorkerRequest::Generate(completion_request)); *lcr = Some(WorkerRequest::Generate(completion_request));
} }
Err(err) => eprintln!("{err:?}"), Err(err) => error!("{err:?}"),
} }
} else if request_is::<GenerateStream>(&req) { } else if request_is::<GenerateStream>(&req) {
match cast::<GenerateStream>(req) { match cast::<GenerateStream>(req) {
@@ -127,10 +133,10 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let completion_request = GenerateStreamRequest::new(id, params); let completion_request = GenerateStreamRequest::new(id, params);
*lcr = Some(WorkerRequest::GenerateStream(completion_request)); *lcr = Some(WorkerRequest::GenerateStream(completion_request));
} }
Err(err) => eprintln!("{err:?}"), Err(err) => error!("{err:?}"),
} }
} else { } else {
eprintln!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream") error!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream")
} }
} }
Message::Notification(not) => { Message::Notification(not) => {
@@ -150,3 +156,69 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
} }
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use crate::memory_backends::Prompt;
use super::*;
use serde_json::json;
#[test]
fn custom_mac_gguf_model() {
let args = json!({
"initializationOptions": {
"memory": {
"file_store": {}
},
"macos": {
"model_gguf": {
"repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
"name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
// "repository": "stabilityai/stablelm-2-zephyr-1_6b",
// "name": "stablelm-2-zephyr-1_6b-Q5_K_M.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
// "fim": {
// "start": "<fim_prefix>",
// "middle": "<fim_suffix>",
// "end": "<fim_middle>"
// },
// "chat": {
// "completion": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
// "generation": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
// // "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
// },
"n_ctx": 2048,
"n_gpu_layers": 35,
}
},
}
});
let configuration = Configuration::new(args).unwrap();
let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
let prompt = Prompt::new("".to_string(), "def fibn".to_string());
let response = backend.do_completion(&prompt).unwrap();
eprintln!("\nRESPONSE:\n{:?}", response.insert_text);
}
}

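Most of the new visibility in this commit comes from the subscriber above plus the `#[instrument]` attributes added throughout. A minimal standalone sketch of that pattern, assuming only the `tracing` and `tracing-subscriber` crates (the `Worker` type and its fields below are illustrative, not from this codebase); output goes to stderr so stdout stays free for the LSP protocol carried over `Connection::stdio()`:

```rust
use tracing::{info, instrument};
use tracing_subscriber::{EnvFilter, FmtSubscriber};

struct Worker {
    name: String,
}

impl Worker {
    // `skip(self)` keeps the receiver out of the span; the remaining
    // arguments are recorded on the span using their Debug representation.
    #[instrument(skip(self))]
    fn handle(&self, request_id: u64) {
        info!(worker = %self.name, "handling request");
    }
}

fn main() {
    // Same setup as the diff: filter directives come from LSP_AI_LOG
    // (e.g. LSP_AI_LOG=debug), defaulting to ERROR when the variable is
    // missing or malformed, and everything is written to stderr.
    FmtSubscriber::builder()
        .with_writer(std::io::stderr)
        .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
        .init();

    Worker { name: "demo".to_string() }.handle(42);
}
```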
View File

@@ -2,10 +2,11 @@ use anyhow::Context;
use lsp_types::TextDocumentPositionParams; use lsp_types::TextDocumentPositionParams;
use ropey::Rope; use ropey::Rope;
use std::collections::HashMap; use std::collections::HashMap;
use tracing::instrument;
use crate::{configuration::Configuration, utils::characters_to_estimated_tokens}; use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};
use super::{MemoryBackend, Prompt}; use super::{MemoryBackend, Prompt, PromptForType};
pub struct FileStore { pub struct FileStore {
configuration: Configuration, configuration: Configuration,
@@ -22,6 +23,7 @@ impl FileStore {
} }
impl MemoryBackend for FileStore { impl MemoryBackend for FileStore {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> { fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
let rope = self let rope = self
.file_map .file_map
@@ -34,7 +36,12 @@ impl MemoryBackend for FileStore {
.to_string()) .to_string())
} }
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt> { #[instrument(skip(self))]
fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
let mut rope = self let mut rope = self
.file_map .file_map
.get(position.text_document.uri.as_str()) .get(position.text_document.uri.as_str())
@@ -44,33 +51,60 @@ impl MemoryBackend for FileStore {
let cursor_index = rope.line_to_char(position.position.line as usize) let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize; + position.position.character as usize;
// We only want to do FIM if the user has enabled it, and the cursor is not at the end of the file let is_chat_enabled = match prompt_for_type {
let code = match self.configuration.get_fim() { PromptForType::Completion => self
Some(fim) if rope.len_chars() != cursor_index => { .configuration
let max_length = .get_chat()?
characters_to_estimated_tokens(self.configuration.get_maximum_context_length()); .map(|c| c.completion.is_some())
.unwrap_or(false),
PromptForType::Generate => self
.configuration
.get_chat()?
.map(|c| c.generation.is_some())
.unwrap_or(false),
};
// We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
// and the user has not enabled chat
let code = match (is_chat_enabled, self.configuration.get_fim()?) {
r @ (true, _) | r @ (false, Some(_))
if is_chat_enabled || rope.len_chars() != cursor_index =>
{
let max_length = characters_to_estimated_tokens(
self.configuration.get_maximum_context_length()?,
);
let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0); let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
let end = rope let end = rope
.len_chars() .len_chars()
.min(cursor_index + (max_length - (cursor_index - start))); .min(cursor_index + (max_length - (cursor_index - start)));
rope.insert(end, &fim.end);
rope.insert(cursor_index, &fim.middle); if is_chat_enabled {
rope.insert(start, &fim.start); rope.insert(cursor_index, "{CURSOR}");
let rope_slice = rope let rope_slice = rope
.get_slice( .get_slice(start..end + "{CURSOR}".chars().count())
start .context("Error getting rope slice")?;
..end rope_slice.to_string()
+ fim.start.chars().count() } else {
+ fim.middle.chars().count() let fim = r.1.unwrap(); // We can unwrap as we know it is some from the match
+ fim.end.chars().count(), rope.insert(end, &fim.end);
) rope.insert(cursor_index, &fim.middle);
.context("Error getting rope slice")?; rope.insert(start, &fim.start);
rope_slice.to_string() let rope_slice = rope
.get_slice(
start
..end
+ fim.start.chars().count()
+ fim.middle.chars().count()
+ fim.end.chars().count(),
)
.context("Error getting rope slice")?;
rope_slice.to_string()
}
} }
_ => { _ => {
let start = cursor_index let start = cursor_index
.checked_sub(characters_to_estimated_tokens( .checked_sub(characters_to_estimated_tokens(
self.configuration.get_maximum_context_length(), self.configuration.get_maximum_context_length()?,
)) ))
.unwrap_or(0); .unwrap_or(0);
let rope_slice = rope let rope_slice = rope
@@ -82,6 +116,7 @@ impl MemoryBackend for FileStore {
Ok(Prompt::new("".to_string(), code)) Ok(Prompt::new("".to_string(), code))
} }
#[instrument(skip(self))]
fn opened_text_document( fn opened_text_document(
&mut self, &mut self,
params: lsp_types::DidOpenTextDocumentParams, params: lsp_types::DidOpenTextDocumentParams,
@@ -92,6 +127,7 @@ impl MemoryBackend for FileStore {
Ok(()) Ok(())
} }
#[instrument(skip(self))]
fn changed_text_document( fn changed_text_document(
&mut self, &mut self,
params: lsp_types::DidChangeTextDocumentParams, params: lsp_types::DidChangeTextDocumentParams,
@@ -116,6 +152,7 @@ impl MemoryBackend for FileStore {
Ok(()) Ok(())
} }
#[instrument(skip(self))]
fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
for file_rename in params.files { for file_rename in params.files {
if let Some(rope) = self.file_map.remove(&file_rename.old_uri) { if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {

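For orientation, the FIM branch above wraps the text around the cursor in the configured `start`, `middle`, and `end` markers, so the model sees prefix and suffix context and is asked to fill the gap. A standalone sketch of that layout on a plain string (the marker values are the illustrative ones from the test configuration, not hard-coded anywhere):

```rust
// Sketch of the fill-in-the-middle prompt layout built by `build_prompt`:
// start marker + text before the cursor + middle marker + text after the
// cursor + end marker.
fn build_fim_prompt(text: &str, cursor: usize, start: &str, middle: &str, end: &str) -> String {
    let (prefix, suffix) = text.split_at(cursor);
    format!("{start}{prefix}{middle}{suffix}{end}")
}

fn main() {
    let prompt = build_fim_prompt(
        "def fib(n):\n    \nprint(fib(10))",
        16,
        "<fim_prefix>",
        "<fim_suffix>",
        "<fim_middle>",
    );
    assert_eq!(
        prompt,
        "<fim_prefix>def fib(n):\n    <fim_suffix>\nprint(fib(10))<fim_middle>"
    );
}
```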
View File

@@ -14,11 +14,17 @@ pub struct Prompt {
} }
impl Prompt { impl Prompt {
fn new(context: String, code: String) -> Self { pub fn new(context: String, code: String) -> Self {
Self { context, code } Self { context, code }
} }
} }
#[derive(Debug)]
pub enum PromptForType {
Completion,
Generate,
}
pub trait MemoryBackend { pub trait MemoryBackend {
fn init(&self) -> anyhow::Result<()> { fn init(&self) -> anyhow::Result<()> {
Ok(()) Ok(())
@@ -26,7 +32,11 @@ pub trait MemoryBackend {
fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>; fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>; fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>; fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt>; fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt>;
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>; fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
} }

View File

@@ -1,39 +0,0 @@
use crate::{
configuration::{Chat, ChatMessage, Configuration},
tokenizer::Tokenizer,
};
use hf_hub::api::sync::{Api, ApiRepo};
// // Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
// const CHATML_CHAT_TEMPLATE: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
// const CHATML_BOS_TOKEN: &str = "<s>";
// const CHATML_EOS_TOKEN: &str = "<|im_end|>";
// // Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
// const MISTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
// const MISTRAL_INSTRUCT_BOS_TOKEN: &str = "<s>";
// const MISTRAL_INSTRUCT_EOS_TOKEN: &str = "</s>";
// // Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
// const MIXTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
pub struct Template {
configuration: Configuration,
}
// impl Template {
// pub fn new(configuration: Configuration) -> Self {
// Self { configuration }
// }
// }
pub fn apply_prompt(
chat_messages: Vec<ChatMessage>,
chat: &Chat,
tokenizer: Option<&Tokenizer>,
) -> anyhow::Result<String> {
// If we have the chat template apply it
// If we have the chat_format see if we have it set
// If we don't have the chat_format set here, try and get the chat_template from the tokenizer_config.json file
anyhow::bail!("Please set chat_template or chat_format. Could not find the information in the tokenizer_config.json file")
}

View File

@@ -1,7 +0,0 @@
pub struct Tokenizer {}
impl Tokenizer {
pub fn maybe_from_repo(repo: ApiRepo) -> anyhow::Result<Option<Self>> {
unimplemented!()
}
}

View File

@@ -1,12 +1,11 @@
use anyhow::Context; use anyhow::Context;
use hf_hub::api::sync::Api; use hf_hub::api::sync::Api;
use tracing::{debug, instrument};
use super::TransformerBackend; use super::TransformerBackend;
use crate::{ use crate::{
configuration::{Chat, Configuration}, configuration::Configuration,
memory_backends::Prompt, memory_backends::Prompt,
template::{apply_prompt, Template},
tokenizer::Tokenizer,
utils::format_chat_messages, utils::format_chat_messages,
worker::{ worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest, DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
@@ -19,10 +18,10 @@ use model::Model;
pub struct LlamaCPP { pub struct LlamaCPP {
model: Model, model: Model,
configuration: Configuration, configuration: Configuration,
tokenizer: Option<Tokenizer>,
} }
impl LlamaCPP { impl LlamaCPP {
#[instrument]
pub fn new(configuration: Configuration) -> anyhow::Result<Self> { pub fn new(configuration: Configuration) -> anyhow::Result<Self> {
let api = Api::new()?; let api = Api::new()?;
let model = configuration.get_model()?; let model = configuration.get_model()?;
@@ -32,44 +31,53 @@ impl LlamaCPP {
.context("Model `name` is required when using GGUF models")?; .context("Model `name` is required when using GGUF models")?;
let repo = api.model(model.repository.to_owned()); let repo = api.model(model.repository.to_owned());
let model_path = repo.get(&name)?; let model_path = repo.get(&name)?;
let tokenizer: Option<Tokenizer> = Tokenizer::maybe_from_repo(repo)?;
let model = Model::new(model_path, configuration.get_model_kwargs()?)?; let model = Model::new(model_path, configuration.get_model_kwargs()?)?;
Ok(Self { Ok(Self {
model, model,
configuration, configuration,
tokenizer,
}) })
} }
}
impl TransformerBackend for LlamaCPP { #[instrument(skip(self))]
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> { fn get_prompt_string(&self, prompt: &Prompt) -> anyhow::Result<String> {
// We need to check that they not only set the `chat` key, but they set the `completion` sub key // We need to check that they not only set the `chat` key, but they set the `completion` sub key
let prompt = match self.configuration.get_chat() { Ok(match self.configuration.get_chat()? {
Some(c) => { Some(c) => {
if let Some(completion_messages) = &c.completion { if let Some(completion_messages) = &c.completion {
let chat_messages = format_chat_messages(completion_messages, prompt); let chat_messages = format_chat_messages(completion_messages, prompt);
apply_prompt(chat_messages, c, self.tokenizer.as_ref())? self.model
.apply_chat_template(chat_messages, c.chat_template.to_owned())?
} else { } else {
prompt.code.to_owned() prompt.code.to_owned()
} }
} }
None => prompt.code.to_owned(), None => prompt.code.to_owned(),
}; })
let max_new_tokens = self.configuration.get_max_new_tokens().completion; }
}
impl TransformerBackend for LlamaCPP {
#[instrument(skip(self))]
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
let prompt = self.get_prompt_string(prompt)?;
// debug!("Prompt string for LLM: {}", prompt);
let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
self.model self.model
.complete(&prompt, max_new_tokens) .complete(&prompt, max_new_tokens)
.map(|insert_text| DoCompletionResponse { insert_text }) .map(|insert_text| DoCompletionResponse { insert_text })
} }
#[instrument(skip(self))]
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> { fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
unimplemented!() let prompt = self.get_prompt_string(prompt)?;
// let max_new_tokens = self.configuration.get_max_new_tokens().generation; // debug!("Prompt string for LLM: {}", prompt);
// self.model let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
// .complete(prompt, max_new_tokens) self.model
// .map(|generated_text| DoGenerateResponse { generated_text }) .complete(&prompt, max_new_tokens)
.map(|generated_text| DoGenerateResponse { generated_text })
} }
#[instrument(skip(self))]
fn do_generate_stream( fn do_generate_stream(
&self, &self,
_request: &GenerateStreamRequest, _request: &GenerateStreamRequest,
@@ -132,7 +140,7 @@ mod tests {
} }
}); });
let configuration = Configuration::new(args).unwrap(); let configuration = Configuration::new(args).unwrap();
let model = LlamaCPP::new(configuration).unwrap(); let _model = LlamaCPP::new(configuration).unwrap();
// let output = model.do_completion("def fibon").unwrap(); // let output = model.do_completion("def fibon").unwrap();
// println!("{}", output.insert_text); // println!("{}", output.insert_text);
} }

View File

@@ -4,13 +4,14 @@ use llama_cpp_2::{
ggml_time_us, ggml_time_us,
llama_backend::LlamaBackend, llama_backend::LlamaBackend,
llama_batch::LlamaBatch, llama_batch::LlamaBatch,
model::{params::LlamaModelParams, AddBos, LlamaModel}, model::{params::LlamaModelParams, AddBos, LlamaChatMessage, LlamaModel},
token::data_array::LlamaTokenDataArray, token::data_array::LlamaTokenDataArray,
}; };
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::{num::NonZeroU32, path::PathBuf, time::Duration}; use std::{num::NonZeroU32, path::PathBuf, time::Duration};
use tracing::{debug, info, instrument};
use crate::configuration::Kwargs; use crate::configuration::{ChatMessage, Kwargs};
static BACKEND: Lazy<LlamaBackend> = Lazy::new(|| LlamaBackend::init().unwrap()); static BACKEND: Lazy<LlamaBackend> = Lazy::new(|| LlamaBackend::init().unwrap());
@@ -20,6 +21,7 @@ pub struct Model {
} }
impl Model { impl Model {
#[instrument]
pub fn new(model_path: PathBuf, kwargs: &Kwargs) -> anyhow::Result<Self> { pub fn new(model_path: PathBuf, kwargs: &Kwargs) -> anyhow::Result<Self> {
// Get n_gpu_layers if set in kwargs // Get n_gpu_layers if set in kwargs
// As a default we set it to 1000, which should put all layers on the GPU // As a default we set it to 1000, which should put all layers on the GPU
@@ -41,9 +43,8 @@ impl Model {
}; };
// Load the model // Load the model
eprintln!("SETTING MODEL AT PATH: {:?}", model_path); debug!("Loading model at path: {:?}", model_path);
let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?; let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
eprintln!("\nMODEL SET\n");
// Get n_ctx if set in kwargs // Get n_ctx if set in kwargs
// As a default we set it to 2048 // As a default we set it to 2048
@@ -60,6 +61,7 @@ impl Model {
Ok(Model { model, n_ctx }) Ok(Model { model, n_ctx })
} }
#[instrument(skip(self))]
pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> { pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
// initialize the context // initialize the context
let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone())); let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone()));
@@ -77,9 +79,7 @@ impl Model {
let n_cxt = ctx.n_ctx() as usize; let n_cxt = ctx.n_ctx() as usize;
let n_kv_req = tokens_list.len() + max_new_tokens; let n_kv_req = tokens_list.len() + max_new_tokens;
eprintln!( info!("n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}");
"n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}"
);
// make sure the KV cache is big enough to hold all the prompt and generated tokens // make sure the KV cache is big enough to hold all the prompt and generated tokens
if n_kv_req > n_cxt { if n_kv_req > n_cxt {
@@ -132,14 +132,29 @@ impl Model {
let t_main_end = ggml_time_us(); let t_main_end = ggml_time_us();
let duration = Duration::from_micros((t_main_end - t_main_start) as u64); let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
eprintln!( info!(
"decoded {} tokens in {:.2} s, speed {:.2} t/s\n", "decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
n_decode, n_decode,
duration.as_secs_f32(), duration.as_secs_f32(),
n_decode as f32 / duration.as_secs_f32() n_decode as f32 / duration.as_secs_f32()
); );
eprintln!("{}", ctx.timings()); info!("{}", ctx.timings());
Ok(output.join("")) Ok(output.join(""))
} }
#[instrument(skip(self))]
pub fn apply_chat_template(
&self,
messages: Vec<ChatMessage>,
template: Option<String>,
) -> anyhow::Result<String> {
let llama_chat_messages = messages
.into_iter()
.map(|c| LlamaChatMessage::new(c.role, c.content))
.collect::<Result<Vec<LlamaChatMessage>, _>>()?;
Ok(self
.model
.apply_chat_template(template, llama_chat_messages, true)?)
}
} }

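A rough usage sketch for the new `apply_chat_template` helper, assuming the `Model` and `ChatMessage` types from this diff are in scope; the message contents are invented, and passing `None` for the template is assumed to fall back to the chat template stored in the GGUF metadata:

```rust
use anyhow::Result;

// Hypothetical helper: format a chat-style completion prompt with the
// model's own chat template. `Model` and `ChatMessage` come from this diff.
fn build_chat_prompt(model: &Model, context: &str, code: &str) -> Result<String> {
    let messages = vec![
        ChatMessage {
            role: "system".to_string(),
            content: format!(
                "You are a code completion chatbot. Use the following context to \
                 complete the next segment of code.\n\n{context}"
            ),
        },
        ChatMessage {
            role: "user".to_string(),
            content: format!("Complete the following code: \n\n{code}"),
        },
    ];
    // `None` is assumed to use the chat template bundled with the GGUF model;
    // a template string from the user's `chat_template` config could be
    // passed instead.
    model.apply_chat_template(messages, None)
}
```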
View File

@@ -9,7 +9,7 @@ pub trait ToResponseError {
impl ToResponseError for anyhow::Error { impl ToResponseError for anyhow::Error {
fn to_response_error(&self, code: i32) -> ResponseError { fn to_response_error(&self, code: i32) -> ResponseError {
ResponseError { ResponseError {
code: -32603, code,
message: self.to_string(), message: self.to_string(),
data: None, data: None,
} }
@@ -25,8 +25,8 @@ pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec
.iter() .iter()
.map(|m| ChatMessage { .map(|m| ChatMessage {
role: m.role.to_owned(), role: m.role.to_owned(),
message: m content: m
.message .content
.replace("{context}", &prompt.context) .replace("{context}", &prompt.context)
.replace("{code}", &prompt.code), .replace("{code}", &prompt.code),
}) })

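The `{context}` / `{code}` substitution above is the only templating applied before the chat template itself. A small test-style sketch of it, assuming crate-internal access to `ChatMessage`, `Prompt`, and `format_chat_messages`:

```rust
// Illustrative only: substitute the prompt's context and code into the
// user-configured chat messages before handing them to the chat template.
#[test]
fn formats_chat_messages() {
    let template = vec![ChatMessage {
        role: "user".to_string(),
        content: "Complete the following code: \n\n{code}".to_string(),
    }];
    let prompt = Prompt::new("".to_string(), "def fib(n):".to_string());
    let rendered = format_chat_messages(&template, &prompt);
    assert_eq!(
        rendered[0].content,
        "Complete the following code: \n\ndef fib(n):"
    );
}
```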
View File

@@ -5,14 +5,15 @@ use lsp_types::{
}; };
use parking_lot::Mutex; use parking_lot::Mutex;
use std::{sync::Arc, thread}; use std::{sync::Arc, thread};
use tracing::instrument;
use crate::custom_requests::generate::{GenerateParams, GenerateResult}; use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::GenerateStreamParams; use crate::custom_requests::generate_stream::GenerateStreamParams;
use crate::memory_backends::MemoryBackend; use crate::memory_backends::{MemoryBackend, PromptForType};
use crate::transformer_backends::TransformerBackend; use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError; use crate::utils::ToResponseError;
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct CompletionRequest { pub struct CompletionRequest {
id: RequestId, id: RequestId,
params: CompletionParams, params: CompletionParams,
@@ -24,7 +25,7 @@ impl CompletionRequest {
} }
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct GenerateRequest { pub struct GenerateRequest {
id: RequestId, id: RequestId,
params: GenerateParams, params: GenerateParams,
@@ -38,7 +39,7 @@ impl GenerateRequest {
// The generate stream is not yet ready but we don't want to remove it // The generate stream is not yet ready but we don't want to remove it
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct GenerateStreamRequest { pub struct GenerateStreamRequest {
id: RequestId, id: RequestId,
params: GenerateStreamParams, params: GenerateStreamParams,
@@ -91,21 +92,17 @@ impl Worker {
} }
} }
#[instrument(skip(self))]
fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> { fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
let prompt = self let prompt = self.memory_backend.lock().build_prompt(
.memory_backend &request.params.text_document_position,
.lock() PromptForType::Completion,
.build_prompt(&request.params.text_document_position)?; )?;
let filter_text = self let filter_text = self
.memory_backend .memory_backend
.lock() .lock()
.get_filter_text(&request.params.text_document_position)?; .get_filter_text(&request.params.text_document_position)?;
eprintln!("\nPROMPT**************\n{:?}\n******************\n", prompt);
let response = self.transformer_backend.do_completion(&prompt)?; let response = self.transformer_backend.do_completion(&prompt)?;
eprintln!(
"\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{:?}\n&&&&&&&&&&&&&&&&&&\n",
response.insert_text
);
let completion_text_edit = TextEdit::new( let completion_text_edit = TextEdit::new(
Range::new( Range::new(
Position::new( Position::new(
@@ -139,12 +136,12 @@ impl Worker {
}) })
} }
#[instrument(skip(self))]
fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> { fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
let prompt = self let prompt = self.memory_backend.lock().build_prompt(
.memory_backend &request.params.text_document_position,
.lock() PromptForType::Generate,
.build_prompt(&request.params.text_document_position)?; )?;
eprintln!("\nPROMPT*************\n{:?}\n************\n", prompt);
let response = self.transformer_backend.do_generate(&prompt)?; let response = self.transformer_backend.do_generate(&prompt)?;
let result = GenerateResult { let result = GenerateResult {
generated_text: response.generated_text, generated_text: response.generated_text,