lsp-ai (mirror of https://github.com/SilasMarvin/lsp-ai.git, synced 2025-12-18 23:14:28 +01:00)

Commit: Checkpoint

Cargo.lock (generated, 93 changes)

@@ -110,9 +110,9 @@ dependencies = [
 
 [[package]]
 name = "anyhow"
-version = "1.0.80"
+version = "1.0.81"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1"
+checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
 
 [[package]]
 name = "assert_cmd"
@@ -131,9 +131,9 @@ dependencies = [
 
 [[package]]
 name = "async-trait"
-version = "0.1.77"
+version = "0.1.78"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
+checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -149,16 +149,6 @@ dependencies = [
  "num-traits",
 ]
 
-[[package]]
-name = "atomic-write-file"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a8204db279bf648d64fe845bd8840f78b39c8132ed4d6a4194c3b10d4b4cfb0b"
-dependencies = [
- "nix",
- "rand",
-]
-
 [[package]]
 name = "autocfg"
 version = "1.1.0"
@@ -532,6 +522,16 @@ dependencies = [
  "typenum",
 ]
 
+[[package]]
+name = "ctrlc"
+version = "3.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345"
+dependencies = [
+ "nix",
+ "windows-sys 0.52.0",
+]
+
 [[package]]
 name = "darling"
 version = "0.14.4"
@@ -1410,6 +1410,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
 [[package]]
 name = "llama-cpp-2"
 version = "0.1.34"
+source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
 dependencies = [
  "llama-cpp-sys-2",
  "thiserror",
@@ -1419,6 +1420,7 @@ dependencies = [
 [[package]]
 name = "llama-cpp-sys-2"
 version = "0.1.34"
+source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
 dependencies = [
  "bindgen",
  "cc",
@@ -1465,6 +1467,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
+ "async-trait",
  "directories",
  "hf-hub",
  "ignore",
@@ -1956,6 +1959,7 @@ dependencies = [
  "chrono",
  "clap",
  "colored",
+ "ctrlc",
  "futures",
  "indicatif",
  "inquire",
@@ -2075,9 +2079,9 @@ dependencies = [
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.78"
+version = "1.0.79"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
+checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e"
 dependencies = [
  "unicode-ident",
 ]
@@ -2224,9 +2228,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
 
 [[package]]
 name = "reqwest"
-version = "0.11.25"
+version = "0.11.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946"
+checksum = "78bf93c4af7a8bb7d879d51cebe797356ff10ae8516ace542b5182d9dcac10b2"
 dependencies = [
  "base64 0.21.7",
  "bytes",
@@ -2769,9 +2773,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf"
+checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa"
 dependencies = [
  "sqlx-core",
  "sqlx-macros",
@@ -2782,9 +2786,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-core"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd"
+checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6"
 dependencies = [
  "ahash",
  "atoi",
@@ -2792,7 +2796,6 @@ dependencies = [
  "bytes",
  "crc",
  "crossbeam-queue",
- "dotenvy",
  "either",
  "event-listener",
  "futures-channel",
@@ -2827,9 +2830,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-macros"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5"
+checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -2840,11 +2843,10 @@ dependencies = [
 
 [[package]]
 name = "sqlx-macros-core"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841"
+checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8"
 dependencies = [
- "atomic-write-file",
  "dotenvy",
  "either",
  "heck",
@@ -2867,9 +2869,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-mysql"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4"
+checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418"
 dependencies = [
  "atoi",
  "base64 0.21.7",
@@ -2911,9 +2913,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-postgres"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24"
+checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e"
 dependencies = [
  "atoi",
  "base64 0.21.7",
@@ -2938,7 +2940,6 @@ dependencies = [
  "rand",
  "serde",
  "serde_json",
- "sha1",
 "sha2",
 "smallvec",
 "sqlx-core",
@@ -2952,9 +2953,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-sqlite"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490"
+checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
 dependencies = [
  "atoi",
  "flume",
@@ -3051,20 +3052,20 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
 
 [[package]]
 name = "system-configuration"
-version = "0.6.0"
+version = "0.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42"
+checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
 dependencies = [
- "bitflags 2.4.2",
+ "bitflags 1.3.2",
  "core-foundation",
  "system-configuration-sys",
 ]
 
 [[package]]
 name = "system-configuration-sys"
-version = "0.6.0"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
+checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
 dependencies = [
  "core-foundation-sys",
  "libc",
@@ -3090,18 +3091,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"
 
 [[package]]
 name = "thiserror"
-version = "1.0.57"
+version = "1.0.58"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
+checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.57"
+version = "1.0.58"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
+checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -3631,9 +3632,9 @@ dependencies = [
 
 [[package]]
 name = "whoami"
-version = "1.5.0"
+version = "1.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0fec781d48b41f8163426ed18e8fc2864c12937df9ce54c88ede7bd47270893e"
+checksum = "a44ab49fad634e88f55bf8f9bb3abd2f27d7204172a112c7c9987e01c1c94ea9"
 dependencies = [
  "redox_syscall",
  "wasite",

Cargo.toml

@@ -20,7 +20,7 @@ parking_lot = "0.12.1"
 once_cell = "1.19.0"
 directories = "5.0.1"
 # llama-cpp-2 = "0.1.31"
-llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
+llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" }
 minijinja = { version = "1.0.12", features = ["loader"] }
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 tracing = "0.1.40"
@@ -30,6 +30,7 @@ ignore = "0.4.22"
 pgml = { path = "submodules/postgresml/pgml-sdks/pgml" }
 tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
 indexmap = "2.2.5"
+async-trait = "0.1.78"
 
 [features]
 default = []

src/configuration.rs (path inferred)

@@ -129,10 +129,6 @@ const fn openai_top_p_default() -> f32 {
     0.95
 }
 
-const fn openai_top_k_default() -> usize {
-    40
-}
-
 const fn openai_presence_penalty() -> f32 {
     0.
 }
@@ -155,7 +151,9 @@ pub struct OpenAI {
     pub auth_token_env_var_name: Option<String>,
     pub auth_token: Option<String>,
     // The completions endpoint
-    pub completions_endpoint: String,
+    pub completions_endpoint: Option<String>,
+    // The chat endpoint
+    pub chat_endpoint: Option<String>,
     // The model name
     pub model: String,
     // Fill in the middle support
@@ -168,8 +166,6 @@ pub struct OpenAI {
     // Other available args
     #[serde(default = "openai_top_p_default")]
     pub top_p: f32,
-    #[serde(default = "openai_top_k_default")]
-    pub top_k: usize,
    #[serde(default = "openai_presence_penalty")]
     pub presence_penalty: f32,
     #[serde(default = "openai_frequency_penalty")]
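
Aside: the #[serde(default = "...")] attributes in this struct point serde at plain const fns that supply a fallback value when the field is absent from the input. A minimal standalone sketch of the pattern (illustrative names, not the crate's actual config):

    use serde::Deserialize;

    // serde calls this when `top_p` is missing from the input.
    const fn top_p_default() -> f32 {
        0.95
    }

    #[derive(Deserialize)]
    struct ExampleConfig {
        #[serde(default = "top_p_default")]
        top_p: f32,
    }

    fn main() {
        let cfg: ExampleConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(cfg.top_p, 0.95);
    }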

src/main.rs (26 changes)

@@ -52,7 +52,7 @@ fn main() -> Result<()> {
         .init();
 
     let (connection, io_threads) = Connection::stdio();
-    let server_capabilities = serde_json::to_value(&ServerCapabilities {
+    let server_capabilities = serde_json::to_value(ServerCapabilities {
         completion_provider: Some(CompletionOptions::default()),
         text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind(
             TextDocumentSyncKind::INCREMENTAL,
@@ -77,7 +77,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     let transformer_backend: Box<dyn TransformerBackend + Send> = args.clone().try_into()?;
 
     // Set the memory_backend
-    let memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>> =
+    let memory_backend: Box<dyn MemoryBackend + Send> =
         Arc::new(Mutex::new(args.clone().try_into()?));
 
     // Wrap the connection for sharing between threads
@@ -87,6 +87,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     let last_worker_request = Arc::new(Mutex::new(None));
 
     // Thread local variables
+    // TODO: Setup some kind of handler for errors here
     let thread_memory_backend = memory_backend.clone();
     let thread_last_worker_request = last_worker_request.clone();
     let thread_connection = connection.clone();
@@ -97,7 +98,8 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
             thread_last_worker_request,
             thread_connection,
         )
-        .run();
+        .run()
+        .unwrap();
     });
 
     for msg in &connection.receiver {
@@ -143,13 +145,13 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
             Message::Notification(not) => {
                 if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
                     let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
-                    memory_backend.lock().opened_text_document(params)?;
+                    // memory_backend.lock().opened_text_document(params)?;
                 } else if notification_is::<lsp_types::notification::DidChangeTextDocument>(&not) {
                     let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
-                    memory_backend.lock().changed_text_document(params)?;
+                    // memory_backend.lock().changed_text_document(params)?;
                 } else if notification_is::<lsp_types::notification::DidRenameFiles>(&not) {
                     let params: RenameFilesParams = serde_json::from_value(not.params)?;
-                    memory_backend.lock().renamed_file(params)?;
+                    // memory_backend.lock().renamed_file(params)?;
                 }
             }
             _ => (),
@@ -170,18 +172,18 @@ mod tests {
     //////////////////////////////////////
     //////////////////////////////////////
 
-    #[test]
-    fn completion_with_default_arguments() {
+    #[tokio::test]
+    async fn completion_with_default_arguments() {
         let args = json!({});
         let configuration = Configuration::new(args).unwrap();
         let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
         let prompt = Prompt::new("".to_string(), "def fibn".to_string());
-        let response = backend.do_completion(&prompt).unwrap();
+        let response = backend.do_completion(&prompt).await.unwrap();
         assert!(!response.insert_text.is_empty())
     }
 
-    #[test]
-    fn completion_with_custom_gguf_model() {
+    #[tokio::test]
+    async fn completion_with_custom_gguf_model() {
         let args = json!({
             "initializationOptions": {
                 "memory": {
@@ -232,7 +234,7 @@ mod tests {
         let configuration = Configuration::new(args).unwrap();
         let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
         let prompt = Prompt::new("".to_string(), "def fibn".to_string());
-        let response = backend.do_completion(&prompt).unwrap();
+        let response = backend.do_completion(&prompt).await.unwrap();
         assert!(!response.insert_text.is_empty());
     }
 }
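
Aside: switching the tests from #[test] to #[tokio::test] is what allows async test bodies and .await inside them. A tiny sketch, assuming tokio with the "macros" and "rt" features enabled:

    #[cfg(test)]
    mod sketch {
        // #[tokio::test] builds a runtime per test and blocks on the async body.
        #[tokio::test]
        async fn awaits_inside_a_test() {
            let v = async { 21 * 2 }.await;
            assert_eq!(v, 42);
        }
    }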

src/memory_backends/file_store.rs (path inferred)

@@ -71,7 +71,7 @@ impl FileStore {
             .iter()
             .filter(|f| **f != current_document_uri)
         {
-            let needed = characters.checked_sub(rope.len_chars()).unwrap_or(0);
+            let needed = characters.saturating_sub(rope.len_chars());
             if needed == 0 {
                 break;
             }
@@ -99,7 +99,7 @@ impl FileStore {
             .clone();
         let cursor_index = rope.line_to_char(position.position.line as usize)
             + position.position.character as usize;
-        let start = cursor_index.checked_sub(characters / 2).unwrap_or(0);
+        let start = cursor_index.saturating_sub(characters / 2);
         let end = rope
             .len_chars()
             .min(cursor_index + (characters - (cursor_index - start)));
@@ -137,15 +137,15 @@ impl FileStore {
                 if is_chat_enabled || rope.len_chars() != cursor_index =>
             {
                 let max_length = tokens_to_estimated_characters(max_context_length);
-                let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
+                let start = cursor_index.saturating_sub(max_length / 2);
                 let end = rope
                     .len_chars()
                     .min(cursor_index + (max_length - (cursor_index - start)));
 
                 if is_chat_enabled {
-                    rope.insert(cursor_index, "{CURSOR}");
+                    rope.insert(cursor_index, "<CURSOR>");
                     let rope_slice = rope
-                        .get_slice(start..end + "{CURSOR}".chars().count())
+                        .get_slice(start..end + "<CURSOR>".chars().count())
                         .context("Error getting rope slice")?;
                     rope_slice.to_string()
                 } else {
@@ -166,9 +166,8 @@ impl FileStore {
                 }
             }
             _ => {
-                let start = cursor_index
-                    .checked_sub(tokens_to_estimated_characters(max_context_length))
-                    .unwrap_or(0);
+                let start =
+                    cursor_index.saturating_sub(tokens_to_estimated_characters(max_context_length));
                 let rope_slice = rope
                     .get_slice(start..cursor_index)
                     .context("Error getting rope slice")?;
@@ -178,9 +177,13 @@ impl FileStore {
     }
 }
 
+#[async_trait::async_trait]
 impl MemoryBackend for FileStore {
     #[instrument(skip(self))]
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
+    async fn get_filter_text(
+        &self,
+        position: &TextDocumentPositionParams,
+    ) -> anyhow::Result<String> {
         let rope = self
             .file_map
             .get(position.text_document.uri.as_str())
@@ -193,7 +196,7 @@ impl MemoryBackend for FileStore {
     }
 
     #[instrument(skip(self))]
-    fn build_prompt(
+    async fn build_prompt(
         &mut self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
@@ -207,7 +210,7 @@ impl MemoryBackend for FileStore {
     }
 
    #[instrument(skip(self))]
-    fn opened_text_document(
+    async fn opened_text_document(
         &mut self,
         params: lsp_types::DidOpenTextDocumentParams,
     ) -> anyhow::Result<()> {
@@ -219,7 +222,7 @@ impl MemoryBackend for FileStore {
     }
 
     #[instrument(skip(self))]
-    fn changed_text_document(
+    async fn changed_text_document(
         &mut self,
         params: lsp_types::DidChangeTextDocumentParams,
     ) -> anyhow::Result<()> {
@@ -246,7 +249,7 @@ impl MemoryBackend for FileStore {
     }
 
     #[instrument(skip(self))]
-    fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
+    async fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
         for file_rename in params.files {
             if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {
                 self.file_map.insert(file_rename.new_uri, rope);
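
Aside: the recurring checked_sub(..).unwrap_or(0) to saturating_sub(..) rewrite in this file is behavior-preserving for unsigned integers; a tiny standalone sketch (not from the commit):

    fn main() {
        let (characters, len_chars): (usize, usize) = (10, 25);
        // Old form: checked subtraction, falling back to 0 on underflow.
        let old = characters.checked_sub(len_chars).unwrap_or(0);
        // New form: saturating subtraction clamps at 0 directly.
        let new = characters.saturating_sub(len_chars);
        assert_eq!(old, new); // both are 0 here
    }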

src/memory_backends/mod.rs (path inferred)

@@ -26,19 +26,29 @@ pub enum PromptForType {
     Generate,
 }
 
+#[async_trait::async_trait]
 pub trait MemoryBackend {
-    fn init(&self) -> anyhow::Result<()> {
+    async fn init(&self) -> anyhow::Result<()> {
         Ok(())
     }
-    fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
-    fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
-    fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
-    fn build_prompt(
+    async fn opened_text_document(
+        &mut self,
+        params: DidOpenTextDocumentParams,
+    ) -> anyhow::Result<()>;
+    async fn changed_text_document(
+        &mut self,
+        params: DidChangeTextDocumentParams,
+    ) -> anyhow::Result<()>;
+    async fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
+    async fn build_prompt(
         &mut self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
     ) -> anyhow::Result<Prompt>;
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
+    async fn get_filter_text(
+        &self,
+        position: &TextDocumentPositionParams,
+    ) -> anyhow::Result<String>;
 }
 
 impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
@@ -55,3 +65,15 @@ impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
         }
     }
 }
+
+// This makes testing much easier. Every transformer backend takes in a prompt. When verifying they work, its
+// easier to just pass in a default prompt.
+#[cfg(test)]
+impl Prompt {
+    pub fn default_with_cursor() -> Self {
+        Self {
+            context: r#"def test_context():\n pass"#.to_string(),
+            code: r#"def test_code():\n <CURSOR>"#.to_string(),
+        }
+    }
+}
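
Aside: #[async_trait::async_trait] is what makes async fns legal in a trait that is still used as a Box<dyn MemoryBackend + Send> trait object; the macro rewrites each method to return a boxed future, Send by default. A minimal sketch of the pattern with an illustrative trait (assumes the async-trait, anyhow, and tokio crates):

    use async_trait::async_trait;

    #[async_trait]
    trait Backend {
        async fn init(&mut self) -> anyhow::Result<()>;
    }

    struct Noop;

    #[async_trait]
    impl Backend for Noop {
        async fn init(&mut self) -> anyhow::Result<()> {
            Ok(())
        }
    }

    // The boxed future is Send, so the trait object can cross threads.
    async fn run(mut backend: Box<dyn Backend + Send>) -> anyhow::Result<()> {
        backend.init().await
    }

    fn main() {
        // A current-thread runtime is enough to drive the sketch.
        tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
            .block_on(run(Box::new(Noop)))
            .unwrap();
    }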

src/memory_backends/postgresml/mod.rs (path inferred)

@@ -40,7 +40,7 @@ impl PostgresML {
         };
         // TODO: Think on the naming of the collection
         // Maybe filter on metadata or I'm not sure
-        let collection = Collection::new("test-lsp-ai-2", Some(database_url))?;
+        let collection = Collection::new("test-lsp-ai-3", Some(database_url))?;
         // TODO: Review the pipeline
         let pipeline = Pipeline::new(
             "v1",
@@ -50,7 +50,7 @@ impl PostgresML {
                 "splitter": {
                     "model": "recursive_character",
                     "parameters": {
-                        "chunk_size": 512,
+                        "chunk_size": 1500,
                         "chunk_overlap": 40
                     }
                 },
@@ -90,7 +90,7 @@ impl PostgresML {
             .into_iter()
             .map(|path| {
                 let text = std::fs::read_to_string(&path)
-                    .expect(format!("Error reading path: {}", path).as_str());
+                    .unwrap_or_else(|_| panic!("Error reading path: {}", path));
                 json!({
                     "id": path,
                     "text": text
@@ -121,24 +121,28 @@ impl PostgresML {
     }
 }
 
+#[async_trait::async_trait]
 impl MemoryBackend for PostgresML {
     #[instrument(skip(self))]
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
-        self.file_store.get_filter_text(position)
+    async fn get_filter_text(
+        &self,
+        position: &TextDocumentPositionParams,
+    ) -> anyhow::Result<String> {
+        self.file_store.get_filter_text(position).await
     }
 
     #[instrument(skip(self))]
-    fn build_prompt(
+    async fn build_prompt(
         &mut self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
     ) -> anyhow::Result<Prompt> {
-        // This is blocking, but that is ok as we only query for it from the worker when we are actually doing a transform
         let query = self
            .file_store
            .get_characters_around_position(position, 512)?;
-        let res = self.runtime.block_on(
-            self.collection.vector_search(
+        let res = self
+            .collection
+            .vector_search(
                json!({
                    "query": {
                        "fields": {
@@ -151,8 +155,8 @@ impl MemoryBackend for PostgresML {
                })
                .into(),
                &mut self.pipeline,
-            ),
-        )?;
+            )
+            .await?;
         let context = res
             .into_iter()
             .map(|c| {
@@ -176,7 +180,7 @@ impl MemoryBackend for PostgresML {
     }
 
     #[instrument(skip(self))]
-    fn opened_text_document(
+    async fn opened_text_document(
         &mut self,
         params: lsp_types::DidOpenTextDocumentParams,
     ) -> anyhow::Result<()> {
@@ -185,68 +189,63 @@ impl MemoryBackend for PostgresML {
         let task_added_pipeline = self.added_pipeline;
         let mut task_collection = self.collection.clone();
         let mut task_pipeline = self.pipeline.clone();
-        self.runtime.spawn(async move {
-            if !task_added_pipeline {
-                task_collection
-                    .add_pipeline(&mut task_pipeline)
-                    .await
-                    .expect("PGML - Error adding pipeline to collection");
-            }
+        if !task_added_pipeline {
+            task_collection
+                .add_pipeline(&mut task_pipeline)
+                .await
+                .expect("PGML - Error adding pipeline to collection");
+        }
+        task_collection
+            .upsert_documents(
+                vec![json!({
+                    "id": path,
+                    "text": text
+                })
+                .into()],
+                None,
+            )
+            .await
+            .expect("PGML - Error upserting documents");
+        self.file_store.opened_text_document(params).await
+    }
+
+    #[instrument(skip(self))]
+    async fn changed_text_document(
+        &mut self,
+        params: lsp_types::DidChangeTextDocumentParams,
+    ) -> anyhow::Result<()> {
+        let path = params.text_document.uri.path().to_owned();
+        self.debounce_tx.send(path)?;
+        self.file_store.changed_text_document(params).await
+    }
+
+    #[instrument(skip(self))]
+    async fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
+        let mut task_collection = self.collection.clone();
+        let task_params = params.clone();
+        for file in task_params.files {
+            task_collection
+                .delete_documents(
+                    json!({
+                        "id": file.old_uri
+                    })
+                    .into(),
+                )
+                .await
+                .expect("PGML - Error deleting file");
+            let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
             task_collection
                 .upsert_documents(
                     vec![json!({
-                        "id": path,
+                        "id": file.new_uri,
                         "text": text
                     })
                     .into()],
                     None,
                 )
                 .await
-                .expect("PGML - Error upserting documents");
-        });
-        self.file_store.opened_text_document(params)
-    }
-
-    #[instrument(skip(self))]
-    fn changed_text_document(
-        &mut self,
-        params: lsp_types::DidChangeTextDocumentParams,
-    ) -> anyhow::Result<()> {
-        let path = params.text_document.uri.path().to_owned();
-        self.debounce_tx.send(path)?;
-        self.file_store.changed_text_document(params)
-    }
-
-    #[instrument(skip(self))]
-    fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
-        let mut task_collection = self.collection.clone();
-        let task_params = params.clone();
-        self.runtime.spawn(async move {
-            for file in task_params.files {
-                task_collection
-                    .delete_documents(
-                        json!({
-                            "id": file.old_uri
-                        })
-                        .into(),
-                    )
-                    .await
-                    .expect("PGML - Error deleting file");
-                let text =
-                    std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
-                task_collection
-                    .upsert_documents(
-                        vec![json!({
-                            "id": file.new_uri,
-                            "text": text
-                        })
-                        .into()],
-                        None,
-                    )
-                    .await
-                    .expect("PGML - Error adding pipeline to collection");
-            }
-        });
-        self.file_store.renamed_file(params)
+                .expect("PGML - Error adding pipeline to collection");
+        }
+        self.file_store.renamed_file(params).await
     }
 }

src/transformer_backends/llama_cpp/mod.rs (path inferred)

@@ -2,7 +2,6 @@ use anyhow::Context;
 use hf_hub::api::sync::ApiBuilder;
 use tracing::{debug, instrument};
 
-use super::TransformerBackend;
 use crate::{
     configuration::{self},
     memory_backends::Prompt,
@@ -16,6 +15,8 @@ use crate::{
 mod model;
 use model::Model;
 
+use super::TransformerBackend;
+
 pub struct LlamaCPP {
     model: Model,
     configuration: configuration::ModelGGUF,
@@ -62,9 +63,10 @@ impl LlamaCPP {
     }
 }
 
+#[async_trait::async_trait]
 impl TransformerBackend for LlamaCPP {
     #[instrument(skip(self))]
-    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
         // let prompt = self.get_prompt_string(prompt)?;
         let prompt = &prompt.code;
         debug!("Prompt string for LLM: {}", prompt);
@@ -75,7 +77,7 @@ impl TransformerBackend for LlamaCPP {
     }
 
     #[instrument(skip(self))]
-    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
         // let prompt = self.get_prompt_string(prompt)?;
         // debug!("Prompt string for LLM: {}", prompt);
         let prompt = &prompt.code;
@@ -86,7 +88,7 @@ impl TransformerBackend for LlamaCPP {
     }
 
     #[instrument(skip(self))]
-    fn do_generate_stream(
+    async fn do_generate_stream(
         &self,
         _request: &GenerateStreamRequest,
     ) -> anyhow::Result<DoGenerateStreamResponse> {

src/transformer_backends/mod.rs (path inferred)

@@ -9,11 +9,11 @@ use crate::{
 mod llama_cpp;
 mod openai;
 
+#[async_trait::async_trait]
 pub trait TransformerBackend {
-    // Should all take an enum of chat messages or just a string for completion
-    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
-    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
-    fn do_generate_stream(
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
+    async fn do_generate_stream(
         &self,
         request: &GenerateStreamRequest,
     ) -> anyhow::Result<DoGenerateStreamResponse>;

src/transformer_backends/openai.rs (path inferred)

@@ -1,16 +1,22 @@
+// Something more about what this file is
+// NOTE: When decoding responses from OpenAI compatbile services, we don't care about every field
+
+use anyhow::Context;
 use serde::Deserialize;
-use serde_json::json;
+use serde_json::{json, Value};
 use tracing::instrument;
 
-use super::TransformerBackend;
 use crate::{
-    configuration,
+    configuration::{self, ChatMessage},
     memory_backends::Prompt,
+    utils::{format_chat_messages, format_context_code},
     worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
 };
 
+use super::TransformerBackend;
+
 pub struct OpenAI {
     configuration: configuration::OpenAI,
 }
@@ -22,7 +28,19 @@ struct OpenAICompletionsChoice {
 
 #[derive(Deserialize)]
 struct OpenAICompletionsResponse {
-    choices: Vec<OpenAICompletionsChoice>,
+    choices: Option<Vec<OpenAICompletionsChoice>>,
+    error: Option<Value>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIChatChoices {
+    message: ChatMessage,
+}
+
+#[derive(Deserialize)]
+struct OpenAIChatResponse {
+    choices: Option<Vec<OpenAIChatChoices>>,
+    error: Option<Value>,
 }
 
 impl OpenAI {
@@ -42,7 +60,12 @@ impl OpenAI {
             anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
         };
         let res: OpenAICompletionsResponse = client
-            .post(&self.configuration.completions_endpoint)
+            .post(
+                self.configuration
+                    .completions_endpoint
+                    .as_ref()
+                    .context("must specify `completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `chat_endpoint` and `chat` messages.")?,
+            )
             .bearer_auth(token)
             .header("Content-Type", "application/json")
             .header("Accept", "application/json")
@@ -51,7 +74,6 @@ impl OpenAI {
                 "max_tokens": max_tokens,
                 "n": 1,
                 "top_p": self.configuration.top_p,
-                "top_k": self.configuration.top_k,
                 "presence_penalty": self.configuration.presence_penalty,
                 "frequency_penalty": self.configuration.frequency_penalty,
                 "temperature": self.configuration.temperature,
@@ -60,34 +82,219 @@ impl OpenAI {
             }))
             .send()?
             .json()?;
-        eprintln!("**********RECEIVED REQUEST********");
-        Ok(res.choices[0].text.clone())
+        if let Some(error) = res.error {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(choices) = res.choices {
+            Ok(choices[0].text.clone())
+        } else {
+            anyhow::bail!("Uknown error while making request to OpenAI")
+        }
+    }
+
+    fn get_chat(&self, messages: Vec<ChatMessage>, max_tokens: usize) -> anyhow::Result<String> {
+        eprintln!(
+            "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
+            messages
+        );
+        let client = reqwest::blocking::Client::new();
+        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
+            std::env::var(env_var_name)?
+        } else if let Some(token) = &self.configuration.auth_token {
+            token.to_string()
+        } else {
+            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
+        };
+        let res: OpenAIChatResponse = client
+            .post(
+                self.configuration
+                    .chat_endpoint
+                    .as_ref()
+                    .context("must specify `completions_endpoint` to use completions")?,
+            )
+            .bearer_auth(token)
+            .header("Content-Type", "application/json")
+            .header("Accept", "application/json")
+            .json(&json!({
+                "model": self.configuration.model,
+                "max_tokens": max_tokens,
+                "n": 1,
+                "top_p": self.configuration.top_p,
+                "presence_penalty": self.configuration.presence_penalty,
+                "frequency_penalty": self.configuration.frequency_penalty,
+                "temperature": self.configuration.temperature,
+                "messages": messages
+            }))
+            .send()?
+            .json()?;
+        if let Some(error) = res.error {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(choices) = res.choices {
+            Ok(choices[0].message.content.clone())
+        } else {
+            anyhow::bail!("Uknown error while making request to OpenAI")
+        }
     }
 }
 
+#[async_trait::async_trait]
 impl TransformerBackend for OpenAI {
     #[instrument(skip(self))]
-    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
         eprintln!("--------------{:?}---------------", prompt);
-        let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
-        let insert_text = self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
+        let max_tokens = self.configuration.max_tokens.completion;
+        let insert_text = match &self.configuration.chat {
+            Some(c) => match &c.completion {
+                Some(completion_messages) => {
+                    let messages = format_chat_messages(completion_messages, prompt);
+                    self.get_chat(messages, max_tokens)?
+                }
+                None => self.get_completion(
+                    &format_context_code(&prompt.context, &prompt.code),
+                    max_tokens,
+                )?,
+            },
+            None => self.get_completion(
+                &format_context_code(&prompt.context, &prompt.code),
+                max_tokens,
+            )?,
+        };
         Ok(DoCompletionResponse { insert_text })
     }
 
     #[instrument(skip(self))]
-    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
         eprintln!("--------------{:?}---------------", prompt);
-        let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
-        let generated_text =
-            self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
+        let max_tokens = self.configuration.max_tokens.generation;
+        let generated_text = match &self.configuration.chat {
+            Some(c) => match &c.generation {
+                Some(completion_messages) => {
+                    let messages = format_chat_messages(completion_messages, prompt);
+                    self.get_chat(messages, max_tokens)?
+                }
+                None => self.get_completion(
+                    &format_context_code(&prompt.context, &prompt.code),
+                    max_tokens,
+                )?,
+            },
+            None => self.get_completion(
+                &format_context_code(&prompt.context, &prompt.code),
+                max_tokens,
+            )?,
+        };
         Ok(DoGenerateResponse { generated_text })
     }
 
     #[instrument(skip(self))]
-    fn do_generate_stream(
+    async fn do_generate_stream(
         &self,
         request: &GenerateStreamRequest,
     ) -> anyhow::Result<DoGenerateStreamResponse> {
         unimplemented!()
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[tokio::test]
+    async fn openai_completion_do_completion() -> anyhow::Result<()> {
+        let configuration: configuration::OpenAI = serde_json::from_value(json!({
+            "completions_endpoint": "https://api.openai.com/v1/completions",
+            "model": "gpt-3.5-turbo-instruct",
+            "auth_token_env_var_name": "OPENAI_API_KEY",
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let openai = OpenAI::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = openai.do_completion(&prompt).await?;
+        assert!(!response.insert_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn openai_chat_do_completion() -> anyhow::Result<()> {
+        let configuration: configuration::OpenAI = serde_json::from_value(json!({
+            "chat_endpoint": "https://api.openai.com/v1/chat/completions",
+            "model": "gpt-3.5-turbo",
+            "auth_token_env_var_name": "OPENAI_API_KEY",
+            "chat": {
+                "completion": [
+                    {
+                        "role": "system",
+                        "content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
+                    },
+                    {
+                        "role": "user",
+                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                    }
+                ],
+            },
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let openai = OpenAI::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = openai.do_completion(&prompt).await?;
+        assert!(!response.insert_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn openai_completion_do_generate() -> anyhow::Result<()> {
+        let configuration: configuration::OpenAI = serde_json::from_value(json!({
+            "completions_endpoint": "https://api.openai.com/v1/completions",
+            "model": "gpt-3.5-turbo-instruct",
+            "auth_token_env_var_name": "OPENAI_API_KEY",
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let openai = OpenAI::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = openai.do_generate(&prompt).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn openai_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: configuration::OpenAI = serde_json::from_value(json!({
+            "config": {
+                "chat_endpoint": "https://api.openai.com/v1/chat/completions",
+                "model": "gpt-3.5-turbo",
+                "auth_token_env_var_name": "OPENAI_API_KEY",
+                "chat": {
+                    "generation": [
+                        {
+                            "role": "system",
+                            "content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
                        },
+                        {
+                            "role": "user",
+                            "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                        }
+                    ]
+                },
+                "max_tokens": {
+                    "completion": 16,
+                    "generation": 64
+                },
+                "max_context": 4096
+            }}))?;
        let openai = OpenAI::new(configuration);
        let prompt = Prompt::default_with_cursor();
        let response = openai.do_generate(&prompt).await?;
        assert!(!response.generated_text.is_empty());
        Ok(())
    }
}
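
Aside: the response structs above make both choices and error Option fields so that success and error payloads from OpenAI-compatible servers both deserialize without failing. A standalone sketch of that decoding pattern (illustrative types, not the crate's):

    use serde::Deserialize;
    use serde_json::Value;

    #[derive(Deserialize)]
    struct Choice {
        text: String,
    }

    #[derive(Deserialize)]
    struct ApiResponse {
        // serde treats a missing field as None for Option fields,
        // so both payload shapes decode cleanly.
        choices: Option<Vec<Choice>>,
        error: Option<Value>,
    }

    fn extract_text(res: ApiResponse) -> Result<String, String> {
        match (res.error, res.choices) {
            (Some(e), _) => Err(e.to_string()),
            (None, Some(choices)) => choices
                .into_iter()
                .next()
                .map(|c| c.text)
                .ok_or_else(|| "empty choices".to_string()),
            (None, None) => Err("unknown response shape".to_string()),
        }
    }

    fn main() {
        let ok: ApiResponse = serde_json::from_str(r#"{"choices":[{"text":"hi"}]}"#).unwrap();
        assert_eq!(extract_text(ok).unwrap(), "hi");
        let err: ApiResponse = serde_json::from_str(r#"{"error":{"message":"bad"}}"#).unwrap();
        assert!(extract_text(err).is_err());
    }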

src/utils.rs (10 changes)

@@ -20,15 +20,19 @@ pub fn tokens_to_estimated_characters(tokens: usize) -> usize {
     tokens * 4
 }
 
-pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {
+pub fn format_chat_messages(messages: &[ChatMessage], prompt: &Prompt) -> Vec<ChatMessage> {
     messages
         .iter()
         .map(|m| ChatMessage {
             role: m.role.to_owned(),
             content: m
                 .content
-                .replace("{context}", &prompt.context)
-                .replace("{code}", &prompt.code),
+                .replace("{CONTEXT}", &prompt.context)
+                .replace("{CODE}", &prompt.code),
         })
         .collect()
 }
+
+pub fn format_context_code(context: &str, code: &str) -> String {
+    format!("{context}\n\n{code}")
+}
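
Aside: the &Vec<ChatMessage> to &[ChatMessage] change is the usual slice-parameter idiom: a slice accepts a Vec or an array by reference through deref coercion. A tiny illustrative sketch:

    fn count(items: &[i32]) -> usize {
        items.len()
    }

    fn main() {
        let v = vec![1, 2, 3];
        assert_eq!(count(&v), 3);      // &Vec<i32> coerces to &[i32]
        assert_eq!(count(&[4, 5]), 2); // array references work too
    }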

src/worker.rs (path inferred)

@@ -71,7 +71,7 @@ pub struct DoGenerateStreamResponse {
 }
 
 pub struct Worker {
-    transformer_backend: Box<dyn TransformerBackend>,
+    transformer_backend: Box<dyn TransformerBackend + Send>,
     memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
     last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
     connection: Arc<Connection>,
@@ -79,7 +79,7 @@ pub struct Worker {
 
 impl Worker {
     pub fn new(
-        transformer_backend: Box<dyn TransformerBackend>,
+        transformer_backend: Box<dyn TransformerBackend + Send>,
         memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
         last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
         connection: Arc<Connection>,
@@ -92,8 +92,7 @@ impl Worker {
         }
     }
 
-    #[instrument(skip(self))]
-    fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
+    async fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
         let prompt = self.memory_backend.lock().build_prompt(
             &request.params.text_document_position,
             PromptForType::Completion,
@@ -102,7 +101,7 @@ impl Worker {
             .memory_backend
             .lock()
             .get_filter_text(&request.params.text_document_position)?;
-        let response = self.transformer_backend.do_completion(&prompt)?;
+        let response = self.transformer_backend.do_completion(&prompt).await?;
         let completion_text_edit = TextEdit::new(
             Range::new(
                 Position::new(
@@ -128,7 +127,7 @@ impl Worker {
             items: vec![item],
         };
         let result = Some(CompletionResponse::List(completion_list));
-        let result = serde_json::to_value(&result).unwrap();
+        let result = serde_json::to_value(result).unwrap();
         Ok(Response {
             id: request.id.clone(),
             result: Some(result),
@@ -137,16 +136,16 @@ impl Worker {
     }
 
     #[instrument(skip(self))]
-    fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
+    async fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
         let prompt = self.memory_backend.lock().build_prompt(
             &request.params.text_document_position,
             PromptForType::Generate,
         )?;
-        let response = self.transformer_backend.do_generate(&prompt)?;
+        let response = self.transformer_backend.do_generate(&prompt).await?;
         let result = GenerateResult {
             generated_text: response.generated_text,
         };
-        let result = serde_json::to_value(&result).unwrap();
+        let result = serde_json::to_value(result).unwrap();
         Ok(Response {
             id: request.id.clone(),
             result: Some(result),
@@ -154,36 +153,48 @@ impl Worker {
         })
     }
 
-    pub fn run(self) {
+    pub fn run(self) -> anyhow::Result<()> {
+        let runtime = tokio::runtime::Builder::new_multi_thread()
+            .worker_threads(4)
+            .enable_all()
+            .build()?;
         loop {
             let option_worker_request: Option<WorkerRequest> = {
                 let mut completion_request = self.last_worker_request.lock();
                 std::mem::take(&mut *completion_request)
             };
             if let Some(request) = option_worker_request {
-                let response = match request {
-                    WorkerRequest::Completion(request) => match self.do_completion(&request) {
-                        Ok(r) => r,
-                        Err(e) => Response {
-                            id: request.id,
-                            result: None,
-                            error: Some(e.to_response_error(-32603)),
-                        },
-                    },
-                    WorkerRequest::Generate(request) => match self.do_generate(&request) {
-                        Ok(r) => r,
-                        Err(e) => Response {
-                            id: request.id,
-                            result: None,
-                            error: Some(e.to_response_error(-32603)),
-                        },
-                    },
-                    WorkerRequest::GenerateStream(_) => panic!("Streaming is not supported yet"),
-                };
-                self.connection
-                    .sender
-                    .send(Message::Response(response))
-                    .expect("Error sending message");
+                runtime.spawn(async move {
+                    let response = match request {
+                        WorkerRequest::Completion(request) => {
+                            match self.do_completion(&request).await {
+                                Ok(r) => r,
+                                Err(e) => Response {
+                                    id: request.id,
+                                    result: None,
+                                    error: Some(e.to_response_error(-32603)),
+                                },
+                            }
+                        }
+                        WorkerRequest::Generate(request) => {
+                            match self.do_generate(&request).await {
+                                Ok(r) => r,
+                                Err(e) => Response {
+                                    id: request.id,
+                                    result: None,
+                                    error: Some(e.to_response_error(-32603)),
+                                },
+                            }
+                        }
+                        WorkerRequest::GenerateStream(_) => {
+                            panic!("Streaming is not supported yet")
+                        }
+                    };
+                    self.connection
+                        .sender
+                        .send(Message::Response(response))
+                        .expect("Error sending message");
+                });
             }
             thread::sleep(std::time::Duration::from_millis(5));
         }
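
Aside: Worker::run now builds its own multi-threaded tokio runtime and hands each request to runtime.spawn from an otherwise synchronous polling loop. A minimal sketch of that shape (hypothetical job source; assumes the tokio and anyhow crates):

    use std::{thread, time::Duration};

    fn main() -> anyhow::Result<()> {
        // Build the runtime explicitly instead of using #[tokio::main].
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(4)
            .enable_all()
            .build()?;
        let mut jobs = vec![1u32, 2, 3].into_iter();
        loop {
            if let Some(job) = jobs.next() {
                // The synchronous loop hands async work to the runtime's threads.
                runtime.spawn(async move {
                    println!("handling job {job}");
                });
            }
            thread::sleep(Duration::from_millis(5));
        }
    }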