Mirror of https://github.com/SilasMarvin/lsp-ai.git (synced 2025-12-18 15:04:29 +01:00)
Checkpoint
Cargo.lock (generated, 93 lines changed)
@@ -110,9 +110,9 @@ dependencies = [
 [[package]]
 name = "anyhow"
-version = "1.0.80"
+version = "1.0.81"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1"
+checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"

 [[package]]
 name = "assert_cmd"
@@ -131,9 +131,9 @@ dependencies = [
 [[package]]
 name = "async-trait"
-version = "0.1.77"
+version = "0.1.78"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
+checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -149,16 +149,6 @@ dependencies = [
  "num-traits",
 ]

-[[package]]
-name = "atomic-write-file"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a8204db279bf648d64fe845bd8840f78b39c8132ed4d6a4194c3b10d4b4cfb0b"
-dependencies = [
- "nix",
- "rand",
-]
-
 [[package]]
 name = "autocfg"
 version = "1.1.0"
@@ -532,6 +522,16 @@ dependencies = [
  "typenum",
 ]

+[[package]]
+name = "ctrlc"
+version = "3.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345"
+dependencies = [
+ "nix",
+ "windows-sys 0.52.0",
+]
+
 [[package]]
 name = "darling"
 version = "0.14.4"
@@ -1410,6 +1410,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
 [[package]]
 name = "llama-cpp-2"
 version = "0.1.34"
+source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
 dependencies = [
  "llama-cpp-sys-2",
  "thiserror",
@@ -1419,6 +1420,7 @@ dependencies = [
 [[package]]
 name = "llama-cpp-sys-2"
 version = "0.1.34"
+source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
 dependencies = [
  "bindgen",
  "cc",
@@ -1465,6 +1467,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
+ "async-trait",
  "directories",
  "hf-hub",
  "ignore",
@@ -1956,6 +1959,7 @@ dependencies = [
  "chrono",
  "clap",
  "colored",
+ "ctrlc",
  "futures",
  "indicatif",
  "inquire",
@@ -2075,9 +2079,9 @@ dependencies = [
 [[package]]
 name = "proc-macro2"
-version = "1.0.78"
+version = "1.0.79"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
+checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e"
 dependencies = [
  "unicode-ident",
 ]
@@ -2224,9 +2228,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
 [[package]]
 name = "reqwest"
-version = "0.11.25"
+version = "0.11.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946"
+checksum = "78bf93c4af7a8bb7d879d51cebe797356ff10ae8516ace542b5182d9dcac10b2"
 dependencies = [
  "base64 0.21.7",
  "bytes",
@@ -2769,9 +2773,9 @@ dependencies = [
 [[package]]
 name = "sqlx"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf"
+checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa"
 dependencies = [
  "sqlx-core",
  "sqlx-macros",
@@ -2782,9 +2786,9 @@ dependencies = [
 [[package]]
 name = "sqlx-core"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd"
+checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6"
 dependencies = [
  "ahash",
  "atoi",
@@ -2792,7 +2796,6 @@ dependencies = [
  "bytes",
  "crc",
  "crossbeam-queue",
- "dotenvy",
  "either",
  "event-listener",
  "futures-channel",
@@ -2827,9 +2830,9 @@ dependencies = [
 [[package]]
 name = "sqlx-macros"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5"
+checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -2840,11 +2843,10 @@ dependencies = [
 [[package]]
 name = "sqlx-macros-core"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841"
+checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8"
 dependencies = [
- "atomic-write-file",
  "dotenvy",
  "either",
  "heck",
@@ -2867,9 +2869,9 @@ dependencies = [
 [[package]]
 name = "sqlx-mysql"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4"
+checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418"
 dependencies = [
  "atoi",
  "base64 0.21.7",
@@ -2911,9 +2913,9 @@ dependencies = [
 [[package]]
 name = "sqlx-postgres"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24"
+checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e"
 dependencies = [
  "atoi",
  "base64 0.21.7",
@@ -2938,7 +2940,6 @@ dependencies = [
  "rand",
  "serde",
  "serde_json",
- "sha1",
  "sha2",
  "smallvec",
  "sqlx-core",
@@ -2952,9 +2953,9 @@ dependencies = [
 [[package]]
 name = "sqlx-sqlite"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490"
+checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
 dependencies = [
  "atoi",
  "flume",
@@ -3051,20 +3052,20 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
 [[package]]
 name = "system-configuration"
-version = "0.6.0"
+version = "0.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42"
+checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
 dependencies = [
- "bitflags 2.4.2",
+ "bitflags 1.3.2",
  "core-foundation",
  "system-configuration-sys",
 ]

 [[package]]
 name = "system-configuration-sys"
-version = "0.6.0"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
+checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
 dependencies = [
  "core-foundation-sys",
  "libc",
@@ -3090,18 +3091,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"
 [[package]]
 name = "thiserror"
-version = "1.0.57"
+version = "1.0.58"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
+checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
 dependencies = [
  "thiserror-impl",
 ]

 [[package]]
 name = "thiserror-impl"
-version = "1.0.57"
+version = "1.0.58"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
+checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -3631,9 +3632,9 @@ dependencies = [
 [[package]]
 name = "whoami"
-version = "1.5.0"
+version = "1.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0fec781d48b41f8163426ed18e8fc2864c12937df9ce54c88ede7bd47270893e"
+checksum = "a44ab49fad634e88f55bf8f9bb3abd2f27d7204172a112c7c9987e01c1c94ea9"
 dependencies = [
  "redox_syscall",
  "wasite",
Cargo.toml
@@ -20,7 +20,7 @@ parking_lot = "0.12.1"
 once_cell = "1.19.0"
 directories = "5.0.1"
 # llama-cpp-2 = "0.1.31"
-llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
+llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" }
 minijinja = { version = "1.0.12", features = ["loader"] }
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 tracing = "0.1.40"
@@ -30,6 +30,7 @@ ignore = "0.4.22"
 pgml = { path = "submodules/postgresml/pgml-sdks/pgml" }
 tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
 indexmap = "2.2.5"
+async-trait = "0.1.78"

 [features]
 default = []
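Context for the llama-cpp-2 switch above (standard Cargo behavior, not specific to this repo): a path dependency is never pinned and carries no source line in Cargo.lock, while a git dependency records the repository URL, the branch, and the exact commit it resolved to. That is why the llama-cpp-2 and llama-cpp-sys-2 lockfile entries earlier in this diff gain source = "git+...#f810fea..." lines; the #f810fea... suffix is the pinned commit. The two declaration forms side by side:

    # Cargo.toml, local checkout: builds whatever is on disk, nothing pinned
    # llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }

    # Cargo.toml, git dependency: Cargo.lock pins a single commit hash
    llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" }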
@@ -129,10 +129,6 @@ const fn openai_top_p_default() -> f32 {
     0.95
 }

-const fn openai_top_k_default() -> usize {
-    40
-}
-
 const fn openai_presence_penalty() -> f32 {
     0.
 }
@@ -155,7 +151,9 @@ pub struct OpenAI {
     pub auth_token_env_var_name: Option<String>,
     pub auth_token: Option<String>,
     // The completions endpoint
-    pub completions_endpoint: String,
+    pub completions_endpoint: Option<String>,
+    // The chat endpoint
+    pub chat_endpoint: Option<String>,
     // The model name
     pub model: String,
     // Fill in the middle support
@@ -168,8 +166,6 @@ pub struct OpenAI {
     // Other available args
     #[serde(default = "openai_top_p_default")]
     pub top_p: f32,
-    #[serde(default = "openai_top_k_default")]
-    pub top_k: usize,
     #[serde(default = "openai_presence_penalty")]
     pub presence_penalty: f32,
     #[serde(default = "openai_frequency_penalty")]
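How those serde attributes behave, shown in isolation (a minimal sketch with illustrative names, not code from this repo): a #[serde(default = "...")] field falls back to the named const fn when the key is absent, and an Option field simply becomes None, which is what lets completions_endpoint and chat_endpoint both be optional now.

    use serde::Deserialize;

    const fn top_p_default() -> f32 {
        0.95
    }

    #[derive(Deserialize)]
    struct Sketch {
        #[serde(default = "top_p_default")]
        top_p: f32,                    // absent key -> 0.95
        chat_endpoint: Option<String>, // absent key -> None
    }

    fn main() {
        let s: Sketch = serde_json::from_str("{}").unwrap();
        assert_eq!(s.top_p, 0.95);
        assert!(s.chat_endpoint.is_none());
    }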
src/main.rs (26 lines changed)
@@ -52,7 +52,7 @@ fn main() -> Result<()> {
         .init();

     let (connection, io_threads) = Connection::stdio();
-    let server_capabilities = serde_json::to_value(&ServerCapabilities {
+    let server_capabilities = serde_json::to_value(ServerCapabilities {
         completion_provider: Some(CompletionOptions::default()),
         text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind(
             TextDocumentSyncKind::INCREMENTAL,
@@ -77,7 +77,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     let transformer_backend: Box<dyn TransformerBackend + Send> = args.clone().try_into()?;

     // Set the memory_backend
-    let memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>> =
+    let memory_backend: Box<dyn MemoryBackend + Send> =
         Arc::new(Mutex::new(args.clone().try_into()?));

     // Wrap the connection for sharing between threads
@@ -87,6 +87,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     let last_worker_request = Arc::new(Mutex::new(None));

     // Thread local variables
+    // TODO: Setup some kind of handler for errors here
     let thread_memory_backend = memory_backend.clone();
     let thread_last_worker_request = last_worker_request.clone();
     let thread_connection = connection.clone();
@@ -97,7 +98,8 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
             thread_last_worker_request,
             thread_connection,
         )
-        .run();
+        .run()
+        .unwrap();
     });

     for msg in &connection.receiver {
@@ -143,13 +145,13 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
             Message::Notification(not) => {
                 if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
                     let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
-                    memory_backend.lock().opened_text_document(params)?;
+                    // memory_backend.lock().opened_text_document(params)?;
                 } else if notification_is::<lsp_types::notification::DidChangeTextDocument>(&not) {
                     let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
-                    memory_backend.lock().changed_text_document(params)?;
+                    // memory_backend.lock().changed_text_document(params)?;
                 } else if notification_is::<lsp_types::notification::DidRenameFiles>(&not) {
                     let params: RenameFilesParams = serde_json::from_value(not.params)?;
-                    memory_backend.lock().renamed_file(params)?;
+                    // memory_backend.lock().renamed_file(params)?;
                 }
             }
             _ => (),
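For orientation, notification_is is the server's dispatch helper; it is not part of this diff, but a minimal sketch of the shape it presumably has is below. It compares the wire method string against the constant that lsp_types defines for each notification type.

    use lsp_types::notification::Notification;

    fn notification_is<N: Notification>(notification: &lsp_server::Notification) -> bool {
        notification.method == N::METHOD
    }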
@@ -170,18 +172,18 @@ mod tests {
     //////////////////////////////////////
     //////////////////////////////////////

-    #[test]
-    fn completion_with_default_arguments() {
+    #[tokio::test]
+    async fn completion_with_default_arguments() {
         let args = json!({});
         let configuration = Configuration::new(args).unwrap();
         let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
         let prompt = Prompt::new("".to_string(), "def fibn".to_string());
-        let response = backend.do_completion(&prompt).unwrap();
+        let response = backend.do_completion(&prompt).await.unwrap();
         assert!(!response.insert_text.is_empty())
     }

-    #[test]
-    fn completion_with_custom_gguf_model() {
+    #[tokio::test]
+    async fn completion_with_custom_gguf_model() {
         let args = json!({
             "initializationOptions": {
                 "memory": {
@@ -232,7 +234,7 @@ mod tests {
         let configuration = Configuration::new(args).unwrap();
         let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
         let prompt = Prompt::new("".to_string(), "def fibn".to_string());
-        let response = backend.do_completion(&prompt).unwrap();
+        let response = backend.do_completion(&prompt).await.unwrap();
         assert!(!response.insert_text.is_empty());
     }
 }
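On the #[test] to #[tokio::test] change above: the attribute wraps the async body in a runtime it builds per test, so the expansion is roughly (a sketch of the default, current-thread form):

    #[test]
    fn completion_with_default_arguments() {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(async {
                // original async test body runs here
            });
    }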
@@ -71,7 +71,7 @@ impl FileStore {
             .iter()
             .filter(|f| **f != current_document_uri)
         {
-            let needed = characters.checked_sub(rope.len_chars()).unwrap_or(0);
+            let needed = characters.saturating_sub(rope.len_chars());
             if needed == 0 {
                 break;
             }
@@ -99,7 +99,7 @@ impl FileStore {
             .clone();
         let cursor_index = rope.line_to_char(position.position.line as usize)
             + position.position.character as usize;
-        let start = cursor_index.checked_sub(characters / 2).unwrap_or(0);
+        let start = cursor_index.saturating_sub(characters / 2);
         let end = rope
             .len_chars()
             .min(cursor_index + (characters - (cursor_index - start)));
@@ -137,15 +137,15 @@ impl FileStore {
                 if is_chat_enabled || rope.len_chars() != cursor_index =>
             {
                 let max_length = tokens_to_estimated_characters(max_context_length);
-                let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
+                let start = cursor_index.saturating_sub(max_length / 2);
                 let end = rope
                     .len_chars()
                     .min(cursor_index + (max_length - (cursor_index - start)));

                 if is_chat_enabled {
-                    rope.insert(cursor_index, "{CURSOR}");
+                    rope.insert(cursor_index, "<CURSOR>");
                     let rope_slice = rope
-                        .get_slice(start..end + "{CURSOR}".chars().count())
+                        .get_slice(start..end + "<CURSOR>".chars().count())
                         .context("Error getting rope slice")?;
                     rope_slice.to_string()
                 } else {
@@ -166,9 +166,8 @@ impl FileStore {
                 }
             }
             _ => {
-                let start = cursor_index
-                    .checked_sub(tokens_to_estimated_characters(max_context_length))
-                    .unwrap_or(0);
+                let start =
+                    cursor_index.saturating_sub(tokens_to_estimated_characters(max_context_length));
                 let rope_slice = rope
                     .get_slice(start..cursor_index)
                     .context("Error getting rope slice")?;
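The checked_sub(..).unwrap_or(0) to saturating_sub(..) swaps above are pure simplifications: for unsigned integers, both expressions clamp underflow to zero.

    assert_eq!(5usize.checked_sub(9).unwrap_or(0), 5usize.saturating_sub(9)); // both 0
    assert_eq!(9usize.checked_sub(5).unwrap_or(0), 9usize.saturating_sub(5)); // both 4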
@@ -178,9 +177,13 @@ impl FileStore {
     }
 }

+#[async_trait::async_trait]
 impl MemoryBackend for FileStore {
     #[instrument(skip(self))]
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
+    async fn get_filter_text(
+        &self,
+        position: &TextDocumentPositionParams,
+    ) -> anyhow::Result<String> {
         let rope = self
             .file_map
             .get(position.text_document.uri.as_str())
@@ -193,7 +196,7 @@ impl MemoryBackend for FileStore {
     }

     #[instrument(skip(self))]
-    fn build_prompt(
+    async fn build_prompt(
         &mut self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
@@ -207,7 +210,7 @@ impl MemoryBackend for FileStore {
     }

     #[instrument(skip(self))]
-    fn opened_text_document(
+    async fn opened_text_document(
         &mut self,
         params: lsp_types::DidOpenTextDocumentParams,
     ) -> anyhow::Result<()> {
@@ -219,7 +222,7 @@ impl MemoryBackend for FileStore {
     }

     #[instrument(skip(self))]
-    fn changed_text_document(
+    async fn changed_text_document(
         &mut self,
         params: lsp_types::DidChangeTextDocumentParams,
     ) -> anyhow::Result<()> {
@@ -246,7 +249,7 @@ impl MemoryBackend for FileStore {
     }

     #[instrument(skip(self))]
-    fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
+    async fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
         for file_rename in params.files {
             if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {
                 self.file_map.insert(file_rename.new_uri, rope);
@@ -26,19 +26,29 @@ pub enum PromptForType {
     Generate,
 }

+#[async_trait::async_trait]
 pub trait MemoryBackend {
-    fn init(&self) -> anyhow::Result<()> {
+    async fn init(&self) -> anyhow::Result<()> {
         Ok(())
     }
-    fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
-    fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
-    fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
-    fn build_prompt(
+    async fn opened_text_document(
+        &mut self,
+        params: DidOpenTextDocumentParams,
+    ) -> anyhow::Result<()>;
+    async fn changed_text_document(
+        &mut self,
+        params: DidChangeTextDocumentParams,
+    ) -> anyhow::Result<()>;
+    async fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
+    async fn build_prompt(
         &mut self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
     ) -> anyhow::Result<Prompt>;
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
+    async fn get_filter_text(
+        &self,
+        position: &TextDocumentPositionParams,
+    ) -> anyhow::Result<String>;
 }

 impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
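Why #[async_trait::async_trait] is needed at all: at this point Rust did not support async fn in traits natively, and trait objects like Box<dyn MemoryBackend + Send> cannot return an unboxed impl Future. The macro rewrites each async fn to return a boxed future, keeping the trait object-safe. In outline (a sketch of the desugaring, not code from this repo):

    use std::future::Future;
    use std::pin::Pin;

    pub trait MemoryBackendDesugared {
        // `async fn get_filter_text(&self, ..) -> anyhow::Result<String>` becomes:
        fn get_filter_text<'a>(
            &'a self,
            position: &'a lsp_types::TextDocumentPositionParams,
        ) -> Pin<Box<dyn Future<Output = anyhow::Result<String>> + Send + 'a>>;
    }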
@@ -55,3 +65,15 @@ impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
         }
     }
 }
+
+// This makes testing much easier. Every transformer backend takes in a prompt. When verifying they work, its
+// easier to just pass in a default prompt.
+#[cfg(test)]
+impl Prompt {
+    pub fn default_with_cursor() -> Self {
+        Self {
+            context: r#"def test_context():\n pass"#.to_string(),
+            code: r#"def test_code():\n <CURSOR>"#.to_string(),
+        }
+    }
+}
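One subtlety when reading the new test fixture: r#"..."# raw strings do not process escapes, so the \n in those literals is a literal backslash followed by n, not a newline.

    assert_eq!(r#"a\nb"#, "a\\nb"); // the raw string keeps the backslash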
@@ -40,7 +40,7 @@ impl PostgresML {
         };
         // TODO: Think on the naming of the collection
         // Maybe filter on metadata or I'm not sure
-        let collection = Collection::new("test-lsp-ai-2", Some(database_url))?;
+        let collection = Collection::new("test-lsp-ai-3", Some(database_url))?;
         // TODO: Review the pipeline
         let pipeline = Pipeline::new(
             "v1",
@@ -50,7 +50,7 @@ impl PostgresML {
                 "splitter": {
                     "model": "recursive_character",
                     "parameters": {
-                        "chunk_size": 512,
+                        "chunk_size": 1500,
                         "chunk_overlap": 40
                     }
                 },
@@ -90,7 +90,7 @@ impl PostgresML {
                 .into_iter()
                 .map(|path| {
                     let text = std::fs::read_to_string(&path)
-                        .expect(format!("Error reading path: {}", path).as_str());
+                        .unwrap_or_else(|_| panic!("Error reading path: {}", path));
                     json!({
                         "id": path,
                         "text": text
@@ -121,24 +121,28 @@ impl PostgresML {
     }
 }

+#[async_trait::async_trait]
 impl MemoryBackend for PostgresML {
     #[instrument(skip(self))]
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
-        self.file_store.get_filter_text(position)
+    async fn get_filter_text(
+        &self,
+        position: &TextDocumentPositionParams,
+    ) -> anyhow::Result<String> {
+        self.file_store.get_filter_text(position).await
     }

     #[instrument(skip(self))]
-    fn build_prompt(
+    async fn build_prompt(
         &mut self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
     ) -> anyhow::Result<Prompt> {
-        // This is blocking, but that is ok as we only query for it from the worker when we are actually doing a transform
         let query = self
             .file_store
             .get_characters_around_position(position, 512)?;
-        let res = self.runtime.block_on(
-            self.collection.vector_search(
+        let res = self
+            .collection
+            .vector_search(
                 json!({
                     "query": {
                         "fields": {
@@ -151,8 +155,8 @@ impl MemoryBackend for PostgresML {
                 })
                 .into(),
                 &mut self.pipeline,
-            ),
-        )?;
+            )
+            .await?;
         let context = res
             .into_iter()
             .map(|c| {
@@ -176,7 +180,7 @@ impl MemoryBackend for PostgresML {
     }

     #[instrument(skip(self))]
-    fn opened_text_document(
+    async fn opened_text_document(
         &mut self,
         params: lsp_types::DidOpenTextDocumentParams,
     ) -> anyhow::Result<()> {
@@ -185,68 +189,63 @@ impl MemoryBackend for PostgresML {
         let task_added_pipeline = self.added_pipeline;
         let mut task_collection = self.collection.clone();
         let mut task_pipeline = self.pipeline.clone();
-        self.runtime.spawn(async move {
-            if !task_added_pipeline {
-                task_collection
-                    .add_pipeline(&mut task_pipeline)
-                    .await
-                    .expect("PGML - Error adding pipeline to collection");
-            }
+        if !task_added_pipeline {
+            task_collection
+                .add_pipeline(&mut task_pipeline)
+                .await
+                .expect("PGML - Error adding pipeline to collection");
+        }
+        task_collection
+            .upsert_documents(
+                vec![json!({
+                    "id": path,
+                    "text": text
+                })
+                .into()],
+                None,
+            )
+            .await
+            .expect("PGML - Error upserting documents");
+        self.file_store.opened_text_document(params).await
+    }
+
+    #[instrument(skip(self))]
+    async fn changed_text_document(
+        &mut self,
+        params: lsp_types::DidChangeTextDocumentParams,
+    ) -> anyhow::Result<()> {
+        let path = params.text_document.uri.path().to_owned();
+        self.debounce_tx.send(path)?;
+        self.file_store.changed_text_document(params).await
+    }
+
+    #[instrument(skip(self))]
+    async fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
+        let mut task_collection = self.collection.clone();
+        let task_params = params.clone();
+        for file in task_params.files {
             task_collection
-                .upsert_documents(
-                    vec![json!({
-                        "id": path,
-                        "text": text
-                    })
-                    .into()],
-                    None,
+                .delete_documents(
+                    json!({
+                        "id": file.old_uri
+                    })
+                    .into(),
                 )
                 .await
-                .expect("PGML - Error upserting documents");
-        });
-        self.file_store.opened_text_document(params)
-    }
-
-    #[instrument(skip(self))]
-    fn changed_text_document(
-        &mut self,
-        params: lsp_types::DidChangeTextDocumentParams,
-    ) -> anyhow::Result<()> {
-        let path = params.text_document.uri.path().to_owned();
-        self.debounce_tx.send(path)?;
-        self.file_store.changed_text_document(params)
-    }
-
-    #[instrument(skip(self))]
-    fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
-        let mut task_collection = self.collection.clone();
-        let task_params = params.clone();
-        self.runtime.spawn(async move {
-            for file in task_params.files {
-                task_collection
-                    .delete_documents(
-                        json!({
-                            "id": file.old_uri
-                        })
-                        .into(),
-                    )
-                    .await
-                    .expect("PGML - Error deleting file");
-                let text =
-                    std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
-                task_collection
-                    .upsert_documents(
-                        vec![json!({
-                            "id": file.new_uri,
-                            "text": text
-                        })
-                        .into()],
-                        None,
-                    )
-                    .await
-                    .expect("PGML - Error adding pipeline to collection");
-            }
-        });
-        self.file_store.renamed_file(params)
+                .expect("PGML - Error deleting file");
+            let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
+            task_collection
+                .upsert_documents(
+                    vec![json!({
+                        "id": file.new_uri,
+                        "text": text
+                    })
+                    .into()],
+                    None,
+                )
+                .await
+                .expect("PGML - Error adding pipeline to collection");
+        }
+        self.file_store.renamed_file(params).await
     }
 }
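The thrust of the PostgresML changes above: now that MemoryBackend methods are async, the backend no longer needs its own runtime handle to call into the async pgml SDK. The block_on and runtime.spawn wrappers drop out and the futures are awaited in place. Paraphrasing the build_prompt change (search_query stands in for the json! payload built above):

    // before: sync method, manually entering a privately owned runtime
    let res = self.runtime.block_on(self.collection.vector_search(search_query, &mut self.pipeline))?;

    // after: async method, awaited on whatever runtime drives the worker
    let res = self.collection.vector_search(search_query, &mut self.pipeline).await?;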
@@ -2,7 +2,6 @@ use anyhow::Context;
 use hf_hub::api::sync::ApiBuilder;
 use tracing::{debug, instrument};

-use super::TransformerBackend;
 use crate::{
     configuration::{self},
     memory_backends::Prompt,
@@ -16,6 +15,8 @@ use crate::{
 mod model;
 use model::Model;

+use super::TransformerBackend;
+
 pub struct LlamaCPP {
     model: Model,
     configuration: configuration::ModelGGUF,
@@ -62,9 +63,10 @@ impl LlamaCPP {
     }
 }

+#[async_trait::async_trait]
 impl TransformerBackend for LlamaCPP {
     #[instrument(skip(self))]
-    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
         // let prompt = self.get_prompt_string(prompt)?;
         let prompt = &prompt.code;
         debug!("Prompt string for LLM: {}", prompt);
@@ -75,7 +77,7 @@ impl TransformerBackend for LlamaCPP {
     }

     #[instrument(skip(self))]
-    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
         // let prompt = self.get_prompt_string(prompt)?;
         // debug!("Prompt string for LLM: {}", prompt);
         let prompt = &prompt.code;
@@ -86,7 +88,7 @@ impl TransformerBackend for LlamaCPP {
     }

     #[instrument(skip(self))]
-    fn do_generate_stream(
+    async fn do_generate_stream(
         &self,
         _request: &GenerateStreamRequest,
     ) -> anyhow::Result<DoGenerateStreamResponse> {
@@ -9,11 +9,11 @@ use crate::{
 mod llama_cpp;
 mod openai;

+#[async_trait::async_trait]
 pub trait TransformerBackend {
     // Should all take an enum of chat messages or just a string for completion
-    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
-    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
-    fn do_generate_stream(
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
+    async fn do_generate_stream(
         &self,
         request: &GenerateStreamRequest,
     ) -> anyhow::Result<DoGenerateStreamResponse>;
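Note the interaction with threading here: async_trait boxes the returned futures and, by default, requires them to be Send; the opt-out spelling is #[async_trait::async_trait(?Send)]. The default is what the worker relies on when it moves a Box<dyn TransformerBackend + Send> into a spawned task. A minimal illustration (illustrative trait, not from this repo):

    #[async_trait::async_trait]
    pub trait Ping {
        async fn ping(&self) -> anyhow::Result<()>; // the future is boxed and `+ Send`
    }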
@@ -1,16 +1,22 @@
-// Something more about what this file is
+// NOTE: When decoding responses from OpenAI compatbile services, we don't care about every field

 use anyhow::Context;
 use serde::Deserialize;
-use serde_json::json;
+use serde_json::{json, Value};
 use tracing::instrument;

-use super::TransformerBackend;
 use crate::{
-    configuration,
+    configuration::{self, ChatMessage},
     memory_backends::Prompt,
+    utils::{format_chat_messages, format_context_code},
     worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
 };

+use super::TransformerBackend;
+
 pub struct OpenAI {
     configuration: configuration::OpenAI,
 }
@@ -22,7 +28,19 @@ struct OpenAICompletionsChoice {

 #[derive(Deserialize)]
 struct OpenAICompletionsResponse {
-    choices: Vec<OpenAICompletionsChoice>,
+    choices: Option<Vec<OpenAICompletionsChoice>>,
+    error: Option<Value>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIChatChoices {
+    message: ChatMessage,
+}
+
+#[derive(Deserialize)]
+struct OpenAIChatResponse {
+    choices: Option<Vec<OpenAIChatChoices>>,
+    error: Option<Value>,
 }

 impl OpenAI {
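The pattern above decodes success and error payloads with one struct: in serde, a missing key for an Option field simply becomes None, so whichever of choices/error the server actually sent is the one that ends up Some. A self-contained sketch (illustrative struct name):

    use serde::Deserialize;
    use serde_json::Value;

    #[derive(Deserialize)]
    struct ApiResponse {
        choices: Option<Vec<Value>>,
        error: Option<Value>,
    }

    fn main() {
        let ok: ApiResponse = serde_json::from_str(r#"{"choices":[{"text":"hi"}]}"#).unwrap();
        let err: ApiResponse = serde_json::from_str(r#"{"error":{"message":"bad key"}}"#).unwrap();
        assert!(ok.choices.is_some() && ok.error.is_none());
        assert!(err.error.is_some() && err.choices.is_none());
    }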
@@ -42,7 +60,12 @@ impl OpenAI {
             anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
         };
         let res: OpenAICompletionsResponse = client
-            .post(&self.configuration.completions_endpoint)
+            .post(
+                self.configuration
+                    .completions_endpoint
+                    .as_ref()
+                    .context("must specify `completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `chat_endpoint` and `chat` messages.")?,
+            )
             .bearer_auth(token)
             .header("Content-Type", "application/json")
             .header("Accept", "application/json")
@@ -51,7 +74,6 @@ impl OpenAI {
                 "max_tokens": max_tokens,
                 "n": 1,
                 "top_p": self.configuration.top_p,
-                "top_k": self.configuration.top_k,
                 "presence_penalty": self.configuration.presence_penalty,
                 "frequency_penalty": self.configuration.frequency_penalty,
                 "temperature": self.configuration.temperature,
@@ -60,34 +82,219 @@ impl OpenAI {
             }))
             .send()?
             .json()?;
-        eprintln!("**********RECEIVED REQUEST********");
-        Ok(res.choices[0].text.clone())
+        if let Some(error) = res.error {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(choices) = res.choices {
+            Ok(choices[0].text.clone())
+        } else {
+            anyhow::bail!("Uknown error while making request to OpenAI")
+        }
+    }
+
+    fn get_chat(&self, messages: Vec<ChatMessage>, max_tokens: usize) -> anyhow::Result<String> {
+        eprintln!(
+            "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
+            messages
+        );
+        let client = reqwest::blocking::Client::new();
+        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
+            std::env::var(env_var_name)?
+        } else if let Some(token) = &self.configuration.auth_token {
+            token.to_string()
+        } else {
+            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
+        };
+        let res: OpenAIChatResponse = client
+            .post(
+                self.configuration
+                    .chat_endpoint
+                    .as_ref()
+                    .context("must specify `completions_endpoint` to use completions")?,
+            )
+            .bearer_auth(token)
+            .header("Content-Type", "application/json")
+            .header("Accept", "application/json")
+            .json(&json!({
+                "model": self.configuration.model,
+                "max_tokens": max_tokens,
+                "n": 1,
+                "top_p": self.configuration.top_p,
+                "presence_penalty": self.configuration.presence_penalty,
+                "frequency_penalty": self.configuration.frequency_penalty,
+                "temperature": self.configuration.temperature,
+                "messages": messages
+            }))
+            .send()?
+            .json()?;
+        if let Some(error) = res.error {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(choices) = res.choices {
+            Ok(choices[0].message.content.clone())
+        } else {
+            anyhow::bail!("Uknown error while making request to OpenAI")
+        }
     }
 }

+#[async_trait::async_trait]
 impl TransformerBackend for OpenAI {
     #[instrument(skip(self))]
-    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
-        let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
-        let insert_text = self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+        eprintln!("--------------{:?}---------------", prompt);
+        let max_tokens = self.configuration.max_tokens.completion;
+        let insert_text = match &self.configuration.chat {
+            Some(c) => match &c.completion {
+                Some(completion_messages) => {
+                    let messages = format_chat_messages(completion_messages, prompt);
+                    self.get_chat(messages, max_tokens)?
+                }
+                None => self.get_completion(
+                    &format_context_code(&prompt.context, &prompt.code),
+                    max_tokens,
+                )?,
+            },
+            None => self.get_completion(
+                &format_context_code(&prompt.context, &prompt.code),
+                max_tokens,
+            )?,
+        };
         Ok(DoCompletionResponse { insert_text })
     }

     #[instrument(skip(self))]
-    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
-        let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
-        let generated_text =
-            self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+        eprintln!("--------------{:?}---------------", prompt);
+        let max_tokens = self.configuration.max_tokens.generation;
+        let generated_text = match &self.configuration.chat {
+            Some(c) => match &c.generation {
+                Some(completion_messages) => {
+                    let messages = format_chat_messages(completion_messages, prompt);
+                    self.get_chat(messages, max_tokens)?
+                }
+                None => self.get_completion(
+                    &format_context_code(&prompt.context, &prompt.code),
+                    max_tokens,
+                )?,
+            },
+            None => self.get_completion(
+                &format_context_code(&prompt.context, &prompt.code),
+                max_tokens,
+            )?,
+        };
         Ok(DoGenerateResponse { generated_text })
     }

     #[instrument(skip(self))]
-    fn do_generate_stream(
+    async fn do_generate_stream(
         &self,
         request: &GenerateStreamRequest,
     ) -> anyhow::Result<DoGenerateStreamResponse> {
         unimplemented!()
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[tokio::test]
+    async fn openai_completion_do_completion() -> anyhow::Result<()> {
+        let configuration: configuration::OpenAI = serde_json::from_value(json!({
+            "completions_endpoint": "https://api.openai.com/v1/completions",
+            "model": "gpt-3.5-turbo-instruct",
+            "auth_token_env_var_name": "OPENAI_API_KEY",
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let openai = OpenAI::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = openai.do_completion(&prompt).await?;
+        assert!(!response.insert_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn openai_chat_do_completion() -> anyhow::Result<()> {
+        let configuration: configuration::OpenAI = serde_json::from_value(json!({
+            "chat_endpoint": "https://api.openai.com/v1/chat/completions",
+            "model": "gpt-3.5-turbo",
+            "auth_token_env_var_name": "OPENAI_API_KEY",
+            "chat": {
+                "completion": [
+                    {
+                        "role": "system",
+                        "content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
+                    },
+                    {
+                        "role": "user",
+                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                    }
+                ],
+            },
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let openai = OpenAI::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = openai.do_completion(&prompt).await?;
+        assert!(!response.insert_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn openai_completion_do_generate() -> anyhow::Result<()> {
+        let configuration: configuration::OpenAI = serde_json::from_value(json!({
+            "completions_endpoint": "https://api.openai.com/v1/completions",
+            "model": "gpt-3.5-turbo-instruct",
+            "auth_token_env_var_name": "OPENAI_API_KEY",
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let openai = OpenAI::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = openai.do_generate(&prompt).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn openai_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: configuration::OpenAI = serde_json::from_value(json!({
+            "config": {
+                "chat_endpoint": "https://api.openai.com/v1/chat/completions",
+                "model": "gpt-3.5-turbo",
+                "auth_token_env_var_name": "OPENAI_API_KEY",
+                "chat": {
+                    "generation": [
+                        {
+                            "role": "system",
+                            "content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
+                        },
+                        {
+                            "role": "user",
+                            "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                        }
+                    ]
+                },
+                "max_tokens": {
+                    "completion": 16,
+                    "generation": 64
+                },
+                "max_context": 4096
+        }}))?;
+        let openai = OpenAI::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = openai.do_generate(&prompt).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+}
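Two reading notes on the OpenAI backend above. First, the new tests call the live API, so they only pass with a real key exported in the OPENAI_API_KEY environment variable. Second, the success path still indexes choices[0], which panics if a server ever returns an empty choices array; a defensive variant (not what the commit does, anyhow::Context is already imported in this file) would be:

    let text = res
        .choices
        .context("response contained neither `choices` nor `error`")?
        .first()
        .context("response contained an empty `choices` array")?
        .text
        .clone();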
src/utils.rs (10 lines changed)
@@ -20,15 +20,19 @@ pub fn tokens_to_estimated_characters(tokens: usize) -> usize {
     tokens * 4
 }

-pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {
+pub fn format_chat_messages(messages: &[ChatMessage], prompt: &Prompt) -> Vec<ChatMessage> {
     messages
         .iter()
         .map(|m| ChatMessage {
             role: m.role.to_owned(),
             content: m
                 .content
-                .replace("{context}", &prompt.context)
-                .replace("{code}", &prompt.code),
+                .replace("{CONTEXT}", &prompt.context)
+                .replace("{CODE}", &prompt.code),
         })
         .collect()
 }
+
+pub fn format_context_code(context: &str, code: &str) -> String {
+    format!("{context}\n\n{code}")
+}
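The &Vec<ChatMessage> to &[ChatMessage] change is the idiomatic fix for clippy's ptr_arg lint: a slice parameter accepts both forms at the call site via deref coercion, so nothing else has to change.

    fn takes_slice(xs: &[u32]) -> usize {
        xs.len()
    }

    fn main() {
        let v = vec![1, 2, 3];
        assert_eq!(takes_slice(&v), 3);      // &Vec<u32> coerces to &[u32]
        assert_eq!(takes_slice(&[4, 5]), 2); // a literal slice works too
    }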
@@ -71,7 +71,7 @@ pub struct DoGenerateStreamResponse {
 }

 pub struct Worker {
-    transformer_backend: Box<dyn TransformerBackend>,
+    transformer_backend: Box<dyn TransformerBackend + Send>,
     memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
     last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
     connection: Arc<Connection>,
@@ -79,7 +79,7 @@ pub struct Worker {

 impl Worker {
     pub fn new(
-        transformer_backend: Box<dyn TransformerBackend>,
+        transformer_backend: Box<dyn TransformerBackend + Send>,
         memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
         last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
         connection: Arc<Connection>,
@@ -92,8 +92,7 @@ impl Worker {
         }
     }

-    #[instrument(skip(self))]
-    fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
+    async fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
         let prompt = self.memory_backend.lock().build_prompt(
             &request.params.text_document_position,
             PromptForType::Completion,
@@ -102,7 +101,7 @@ impl Worker {
             .memory_backend
             .lock()
             .get_filter_text(&request.params.text_document_position)?;
-        let response = self.transformer_backend.do_completion(&prompt)?;
+        let response = self.transformer_backend.do_completion(&prompt).await?;
         let completion_text_edit = TextEdit::new(
             Range::new(
                 Position::new(
@@ -128,7 +127,7 @@ impl Worker {
             items: vec![item],
         };
         let result = Some(CompletionResponse::List(completion_list));
-        let result = serde_json::to_value(&result).unwrap();
+        let result = serde_json::to_value(result).unwrap();
         Ok(Response {
             id: request.id.clone(),
             result: Some(result),
@@ -137,16 +136,16 @@ impl Worker {
     }

     #[instrument(skip(self))]
-    fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
+    async fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
         let prompt = self.memory_backend.lock().build_prompt(
             &request.params.text_document_position,
             PromptForType::Generate,
         )?;
-        let response = self.transformer_backend.do_generate(&prompt)?;
+        let response = self.transformer_backend.do_generate(&prompt).await?;
         let result = GenerateResult {
             generated_text: response.generated_text,
         };
-        let result = serde_json::to_value(&result).unwrap();
+        let result = serde_json::to_value(result).unwrap();
         Ok(Response {
             id: request.id.clone(),
             result: Some(result),
@@ -154,36 +153,48 @@ impl Worker {
         })
     }

-    pub fn run(self) {
+    pub fn run(self) -> anyhow::Result<()> {
+        let runtime = tokio::runtime::Builder::new_multi_thread()
+            .worker_threads(4)
+            .enable_all()
+            .build()?;
         loop {
             let option_worker_request: Option<WorkerRequest> = {
                 let mut completion_request = self.last_worker_request.lock();
                 std::mem::take(&mut *completion_request)
             };
             if let Some(request) = option_worker_request {
-                let response = match request {
-                    WorkerRequest::Completion(request) => match self.do_completion(&request) {
-                        Ok(r) => r,
-                        Err(e) => Response {
-                            id: request.id,
-                            result: None,
-                            error: Some(e.to_response_error(-32603)),
-                        },
-                    },
-                    WorkerRequest::Generate(request) => match self.do_generate(&request) {
-                        Ok(r) => r,
-                        Err(e) => Response {
-                            id: request.id,
-                            result: None,
-                            error: Some(e.to_response_error(-32603)),
-                        },
-                    },
-                    WorkerRequest::GenerateStream(_) => panic!("Streaming is not supported yet"),
-                };
-                self.connection
-                    .sender
-                    .send(Message::Response(response))
-                    .expect("Error sending message");
+                runtime.spawn(async move {
+                    let response = match request {
+                        WorkerRequest::Completion(request) => {
+                            match self.do_completion(&request).await {
+                                Ok(r) => r,
+                                Err(e) => Response {
+                                    id: request.id,
+                                    result: None,
+                                    error: Some(e.to_response_error(-32603)),
+                                },
+                            }
+                        }
+                        WorkerRequest::Generate(request) => {
+                            match self.do_generate(&request).await {
+                                Ok(r) => r,
+                                Err(e) => Response {
+                                    id: request.id,
+                                    result: None,
+                                    error: Some(e.to_response_error(-32603)),
+                                },
+                            }
+                        }
+                        WorkerRequest::GenerateStream(_) => {
+                            panic!("Streaming is not supported yet")
+                        }
+                    };
+                    self.connection
+                        .sender
+                        .send(Message::Response(response))
+                        .expect("Error sending message");
+                });
             }
             thread::sleep(std::time::Duration::from_millis(5));
         }
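The shape of the reworked worker loop, reduced to its essentials (a sketch with hypothetical take_last_request/handle helpers standing in for the real request plumbing): a dedicated multi-thread Tokio runtime lives inside the worker thread, and each polled request is handled in a spawned task so one slow completion cannot stall the polling loop.

    let runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(4)
        .enable_all()
        .build()?;
    loop {
        if let Some(request) = take_last_request() {
            runtime.spawn(async move {
                handle(request).await; // build the prompt, call the backend, send the LSP response
            });
        }
        std::thread::sleep(std::time::Duration::from_millis(5));
    }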