mirror of https://github.com/SilasMarvin/lsp-ai.git
synced 2025-12-18 15:04:29 +01:00

Merge pull request #1 from SilasMarvin/silas-async-overhaul

Silas async overhaul

Cargo.lock (generated): 94 lines changed
@@ -110,9 +110,9 @@ dependencies = [
 
 [[package]]
 name = "anyhow"
-version = "1.0.80"
+version = "1.0.81"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1"
+checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
 
 [[package]]
 name = "assert_cmd"
@@ -131,9 +131,9 @@ dependencies = [
 
 [[package]]
 name = "async-trait"
-version = "0.1.77"
+version = "0.1.78"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
+checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -149,16 +149,6 @@ dependencies = [
  "num-traits",
 ]
 
-[[package]]
-name = "atomic-write-file"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a8204db279bf648d64fe845bd8840f78b39c8132ed4d6a4194c3b10d4b4cfb0b"
-dependencies = [
- "nix",
- "rand",
-]
-
 [[package]]
 name = "autocfg"
 version = "1.1.0"
@@ -532,6 +522,16 @@ dependencies = [
  "typenum",
 ]
 
+[[package]]
+name = "ctrlc"
+version = "3.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345"
+dependencies = [
+ "nix",
+ "windows-sys 0.52.0",
+]
+
 [[package]]
 name = "darling"
 version = "0.14.4"
@@ -1410,6 +1410,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
 [[package]]
 name = "llama-cpp-2"
 version = "0.1.34"
+source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
 dependencies = [
  "llama-cpp-sys-2",
  "thiserror",
@@ -1419,6 +1420,7 @@ dependencies = [
 [[package]]
 name = "llama-cpp-sys-2"
 version = "0.1.34"
+source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
 dependencies = [
  "bindgen",
  "cc",
@@ -1465,6 +1467,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
+ "async-trait",
  "directories",
  "hf-hub",
  "ignore",
@@ -1956,6 +1959,7 @@ dependencies = [
  "chrono",
  "clap",
  "colored",
+ "ctrlc",
  "futures",
  "indicatif",
  "inquire",
@@ -1963,6 +1967,7 @@ dependencies = [
  "itertools 0.10.5",
  "lopdf",
  "md5",
  "once_cell",
+ "parking_lot",
  "regex",
  "reqwest",
@@ -2075,9 +2080,9 @@ dependencies = [
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.78"
+version = "1.0.79"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
+checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e"
 dependencies = [
  "unicode-ident",
 ]
@@ -2224,9 +2229,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
 
 [[package]]
 name = "reqwest"
-version = "0.11.25"
+version = "0.11.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946"
+checksum = "78bf93c4af7a8bb7d879d51cebe797356ff10ae8516ace542b5182d9dcac10b2"
 dependencies = [
  "base64 0.21.7",
  "bytes",
@@ -2769,9 +2774,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf"
+checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa"
 dependencies = [
  "sqlx-core",
  "sqlx-macros",
@@ -2782,9 +2787,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-core"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd"
+checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6"
 dependencies = [
  "ahash",
  "atoi",
@@ -2792,7 +2797,6 @@ dependencies = [
  "bytes",
  "crc",
  "crossbeam-queue",
- "dotenvy",
  "either",
  "event-listener",
  "futures-channel",
@@ -2827,9 +2831,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-macros"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5"
+checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -2840,11 +2844,10 @@ dependencies = [
 
 [[package]]
 name = "sqlx-macros-core"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841"
+checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8"
 dependencies = [
- "atomic-write-file",
  "dotenvy",
  "either",
  "heck",
@@ -2867,9 +2870,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-mysql"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4"
+checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418"
 dependencies = [
  "atoi",
  "base64 0.21.7",
@@ -2911,9 +2914,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-postgres"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24"
+checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e"
 dependencies = [
  "atoi",
  "base64 0.21.7",
@@ -2938,7 +2941,6 @@ dependencies = [
  "rand",
  "serde",
  "serde_json",
- "sha1",
  "sha2",
  "smallvec",
  "sqlx-core",
@@ -2952,9 +2954,9 @@ dependencies = [
 
 [[package]]
 name = "sqlx-sqlite"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490"
+checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
 dependencies = [
  "atoi",
  "flume",
@@ -3051,20 +3053,20 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
 
 [[package]]
 name = "system-configuration"
-version = "0.6.0"
+version = "0.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42"
+checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
 dependencies = [
- "bitflags 2.4.2",
+ "bitflags 1.3.2",
  "core-foundation",
  "system-configuration-sys",
 ]
 
 [[package]]
 name = "system-configuration-sys"
-version = "0.6.0"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
+checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
 dependencies = [
  "core-foundation-sys",
  "libc",
@@ -3090,18 +3092,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"
 
 [[package]]
 name = "thiserror"
-version = "1.0.57"
+version = "1.0.58"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
+checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.57"
+version = "1.0.58"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
+checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -3631,9 +3633,9 @@ dependencies = [
 
 [[package]]
 name = "whoami"
-version = "1.5.0"
+version = "1.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0fec781d48b41f8163426ed18e8fc2864c12937df9ce54c88ede7bd47270893e"
+checksum = "a44ab49fad634e88f55bf8f9bb3abd2f27d7204172a112c7c9987e01c1c94ea9"
 dependencies = [
  "redox_syscall",
  "wasite",
Cargo.toml

@@ -20,7 +20,7 @@ parking_lot = "0.12.1"
 once_cell = "1.19.0"
 directories = "5.0.1"
 # llama-cpp-2 = "0.1.31"
-llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
+llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" }
 minijinja = { version = "1.0.12", features = ["loader"] }
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 tracing = "0.1.40"
@@ -30,6 +30,7 @@ ignore = "0.4.22"
 pgml = { path = "submodules/postgresml/pgml-sdks/pgml" }
 tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
 indexmap = "2.2.5"
+async-trait = "0.1.78"
 
 [features]
 default = []
src/configuration.rs

@@ -19,6 +19,7 @@ pub enum ValidMemoryBackend {
 pub enum ValidTransformerBackend {
     LlamaCPP(ModelGGUF),
     OpenAI(OpenAI),
+    Anthropic(Anthropic),
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -36,6 +37,7 @@ pub struct Chat {
 }
 
 #[derive(Clone, Debug, Deserialize)]
+#[allow(clippy::upper_case_acronyms)]
 pub struct FIM {
     pub start: String,
     pub middle: String,
@@ -129,10 +131,6 @@ const fn openai_top_p_default() -> f32 {
     0.95
 }
 
-const fn openai_top_k_default() -> usize {
-    40
-}
-
 const fn openai_presence_penalty() -> f32 {
     0.
 }
@@ -149,13 +147,15 @@ const fn openai_max_context() -> usize {
     DEFAULT_OPENAI_MAX_CONTEXT
 }
 
-#[derive(Clone, Debug, Default, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub struct OpenAI {
     // The auth token env var name
     pub auth_token_env_var_name: Option<String>,
     pub auth_token: Option<String>,
     // The completions endpoint
-    pub completions_endpoint: String,
+    pub completions_endpoint: Option<String>,
+    // The chat endpoint
+    pub chat_endpoint: Option<String>,
     // The model name
     pub model: String,
     // Fill in the middle support
@@ -168,8 +168,6 @@ pub struct OpenAI {
     // Other available args
     #[serde(default = "openai_top_p_default")]
     pub top_p: f32,
-    #[serde(default = "openai_top_k_default")]
-    pub top_k: usize,
     #[serde(default = "openai_presence_penalty")]
     pub presence_penalty: f32,
     #[serde(default = "openai_frequency_penalty")]
@@ -180,9 +178,35 @@ pub struct OpenAI {
     max_context: usize,
 }
 
+#[derive(Clone, Debug, Deserialize)]
+pub struct Anthropic {
+    // The auth token env var name
+    pub auth_token_env_var_name: Option<String>,
+    pub auth_token: Option<String>,
+    // The completions endpoint
+    pub completions_endpoint: Option<String>,
+    // The chat endpoint
+    pub chat_endpoint: Option<String>,
+    // The model name
+    pub model: String,
+    // Fill in the middle support
+    pub fim: Option<FIM>,
+    // The maximum number of new tokens to generate
+    #[serde(default)]
+    pub max_tokens: MaxTokens,
+    // Chat args
+    pub chat: Chat,
+    // System prompt
+    #[serde(default = "openai_top_p_default")]
+    pub top_p: f32,
+    #[serde(default = "openai_temperature")]
+    pub temperature: f32,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 struct ValidTransformerConfiguration {
     openai: Option<OpenAI>,
+    anthropic: Option<Anthropic>,
     model_gguf: Option<ModelGGUF>,
 }
 
@@ -190,6 +214,7 @@ impl Default for ValidTransformerConfiguration {
     fn default() -> Self {
         Self {
             model_gguf: Some(ModelGGUF::default()),
+            anthropic: None,
             openai: None,
         }
     }
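A note for readers tracing the configuration changes: with serde, `Option` fields deserialize to `None` when absent, and `#[serde(default = "...")]` supplies fallbacks, so a user config only has to name a model. A minimal self-contained sketch of that behavior (the struct below is a trimmed stand-in, not the crate's full `Anthropic` type):

use serde::Deserialize;
use serde_json::json;

const fn top_p_default() -> f32 {
    0.95
}

#[derive(Debug, Deserialize)]
struct Anthropic {
    // Option fields implicitly default to None when missing
    auth_token_env_var_name: Option<String>,
    model: String,
    // Falls back to 0.95 when omitted, mirroring `openai_top_p_default`
    #[serde(default = "top_p_default")]
    top_p: f32,
}

fn main() -> Result<(), serde_json::Error> {
    // Only `model` is required here; everything else is optional or defaulted
    let config: Anthropic = serde_json::from_value(json!({
        "model": "claude-3-haiku-20240307"
    }))?;
    assert_eq!(config.top_p, 0.95);
    assert!(config.auth_token_env_var_name.is_none());
    println!("{config:?}");
    Ok(())
}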
src/main.rs: 76 lines changed
@@ -6,25 +6,31 @@ use lsp_types::{
     RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
 };
 use parking_lot::Mutex;
-use std::{sync::Arc, thread};
+use std::{
+    sync::{mpsc, Arc},
+    thread,
+};
 use tracing::error;
 use tracing_subscriber::{EnvFilter, FmtSubscriber};
 
 mod configuration;
 mod custom_requests;
 mod memory_backends;
+mod memory_worker;
 mod template;
 mod transformer_backends;
+mod transformer_worker;
 mod utils;
-mod worker;
 
 use configuration::Configuration;
 use custom_requests::generate::Generate;
 use memory_backends::MemoryBackend;
 use transformer_backends::TransformerBackend;
-use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest};
+use transformer_worker::{CompletionRequest, GenerateRequest, WorkerRequest};
 
-use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest};
+use crate::{
+    custom_requests::generate_stream::GenerateStream, transformer_worker::GenerateStreamRequest,
+};
 
 fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
     notification.method == N::METHOD
@@ -52,7 +58,7 @@ fn main() -> Result<()> {
         .init();
 
     let (connection, io_threads) = Connection::stdio();
-    let server_capabilities = serde_json::to_value(&ServerCapabilities {
+    let server_capabilities = serde_json::to_value(ServerCapabilities {
         completion_provider: Some(CompletionOptions::default()),
         text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind(
             TextDocumentSyncKind::INCREMENTAL,
@@ -66,38 +72,40 @@ fn main() -> Result<()> {
     Ok(())
 }
 
 // This main loop is tricky
 // We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get
 // Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request
 // Note that we also want to have the memory backend in the worker thread as that may also involve heavy computations
 fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
-    let args = Configuration::new(args)?;
-
-    // Set the transformer_backend
-    let transformer_backend: Box<dyn TransformerBackend + Send> = args.clone().try_into()?;
-
-    // Set the memory_backend
-    let memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>> =
-        Arc::new(Mutex::new(args.clone().try_into()?));
+    // Build our configuration
+    let configuration = Configuration::new(args)?;
 
+    // Wrap the connection for sharing between threads
     let connection = Arc::new(connection);
 
-    // How we communicate between the worker and receiver threads
+    // Our channel we use to communicate with our transformer_worker
    let last_worker_request = Arc::new(Mutex::new(None));
 
+    // Setup our memory_worker
+    // TODO: Setup some kind of error handler
+    // Set the memory_backend
+    // The channel we use to communicate with our memory_worker
+    let (memory_tx, memory_rx) = mpsc::channel();
+    let memory_backend: Box<dyn MemoryBackend + Send + Sync> = configuration.clone().try_into()?;
+    thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
 
-    // Thread local variables
-    let thread_memory_backend = memory_backend.clone();
+    // Setup our transformer_worker
+    // TODO: Setup some kind of handler for errors here
+    // Set the transformer_backend
+    let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
+        configuration.clone().try_into()?;
     let thread_last_worker_request = last_worker_request.clone();
     let thread_connection = connection.clone();
+    let thread_memory_tx = memory_tx.clone();
     thread::spawn(move || {
-        Worker::new(
+        transformer_worker::run(
             transformer_backend,
-            thread_memory_backend,
+            thread_memory_tx,
             thread_last_worker_request,
             thread_connection,
         )
-        .run();
     });
 
     for msg in &connection.receiver {
@@ -143,13 +151,13 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
             Message::Notification(not) => {
                 if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
                     let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
-                    memory_backend.lock().opened_text_document(params)?;
+                    memory_tx.send(memory_worker::WorkerRequest::DidOpenTextDocument(params))?;
                 } else if notification_is::<lsp_types::notification::DidChangeTextDocument>(&not) {
                     let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
-                    memory_backend.lock().changed_text_document(params)?;
+                    memory_tx.send(memory_worker::WorkerRequest::DidChangeTextDocument(params))?;
                 } else if notification_is::<lsp_types::notification::DidRenameFiles>(&not) {
                     let params: RenameFilesParams = serde_json::from_value(not.params)?;
-                    memory_backend.lock().renamed_file(params)?;
+                    memory_tx.send(memory_worker::WorkerRequest::DidRenameFiles(params))?;
                 }
             }
             _ => (),
@@ -170,18 +178,19 @@ mod tests {
     //////////////////////////////////////
     //////////////////////////////////////
 
-    #[test]
-    fn completion_with_default_arguments() {
+    #[tokio::test]
+    async fn completion_with_default_arguments() {
         let args = json!({});
         let configuration = Configuration::new(args).unwrap();
-        let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
+        let backend: Box<dyn TransformerBackend + Send + Sync> =
+            configuration.clone().try_into().unwrap();
         let prompt = Prompt::new("".to_string(), "def fibn".to_string());
-        let response = backend.do_completion(&prompt).unwrap();
+        let response = backend.do_completion(&prompt).await.unwrap();
         assert!(!response.insert_text.is_empty())
     }
 
-    #[test]
-    fn completion_with_custom_gguf_model() {
+    #[tokio::test]
+    async fn completion_with_custom_gguf_model() {
         let args = json!({
             "initializationOptions": {
                 "memory": {
@@ -230,9 +239,10 @@ mod tests {
             }
         });
         let configuration = Configuration::new(args).unwrap();
-        let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
+        let backend: Box<dyn TransformerBackend + Send + Sync> =
+            configuration.clone().try_into().unwrap();
         let prompt = Prompt::new("".to_string(), "def fibn".to_string());
-        let response = backend.do_completion(&prompt).unwrap();
+        let response = backend.do_completion(&prompt).await.unwrap();
         assert!(!response.insert_text.is_empty());
     }
 }
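The comment on `main_loop` above explains the scheduling trick: the receiver thread only records the most recent completion request in `last_worker_request`, and the worker handles whatever is newest when it becomes free, so stale requests are silently dropped. A standalone sketch of that latest-request-wins pattern (illustrative names, not the crate's actual types):

use std::{sync::{Arc, Mutex}, thread, time::Duration};

#[derive(Clone, Debug)]
struct Request(u32);

fn main() {
    // Shared slot holding only the latest request; older ones are overwritten
    let last_request: Arc<Mutex<Option<Request>>> = Arc::new(Mutex::new(None));

    let worker_slot = last_request.clone();
    let worker = thread::spawn(move || {
        for _ in 0..5 {
            // take() empties the slot so the same request is never handled twice
            if let Some(req) = worker_slot.lock().unwrap().take() {
                println!("handling latest request: {req:?}");
            }
            thread::sleep(Duration::from_millis(20));
        }
    });

    // The receiver side overwrites the slot as requests stream in
    for i in 0..100 {
        *last_request.lock().unwrap() = Some(Request(i));
    }
    worker.join().unwrap();
}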
src/memory_backends/file_store.rs

@@ -1,8 +1,9 @@
 use anyhow::Context;
 use indexmap::IndexSet;
 use lsp_types::TextDocumentPositionParams;
+use parking_lot::Mutex;
 use ropey::Rope;
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 use tracing::instrument;
 
 use crate::{
@@ -15,8 +16,8 @@ use super::{MemoryBackend, Prompt, PromptForType};
 pub struct FileStore {
     crawl: bool,
     configuration: Configuration,
-    file_map: HashMap<String, Rope>,
-    accessed_files: IndexSet<String>,
+    file_map: Mutex<HashMap<String, Rope>>,
+    accessed_files: Mutex<IndexSet<String>>,
 }
 
 // TODO: Put some thought into the crawling here. Do we want to have a crawl option where it tries to crawl through all relevant
@@ -37,8 +38,8 @@ impl FileStore {
         Self {
             crawl: file_store_config.crawl,
             configuration,
-            file_map: HashMap::new(),
-            accessed_files: IndexSet::new(),
+            file_map: Mutex::new(HashMap::new()),
+            accessed_files: Mutex::new(IndexSet::new()),
         }
     }
 
@@ -46,8 +47,8 @@ impl FileStore {
         Self {
             crawl: false,
             configuration,
-            file_map: HashMap::new(),
-            accessed_files: IndexSet::new(),
+            file_map: Mutex::new(HashMap::new()),
+            accessed_files: Mutex::new(IndexSet::new()),
         }
     }
 
@@ -60,6 +61,7 @@ impl FileStore {
         let current_document_uri = position.text_document.uri.to_string();
         let mut rope = self
             .file_map
+            .lock()
             .get(&current_document_uri)
             .context("Error file not found")?
             .clone();
@@ -68,14 +70,16 @@ impl FileStore {
         // Add to our rope if we need to
         for file in self
             .accessed_files
+            .lock()
             .iter()
             .filter(|f| **f != current_document_uri)
         {
-            let needed = characters.checked_sub(rope.len_chars()).unwrap_or(0);
+            let needed = characters.saturating_sub(rope.len_chars());
             if needed == 0 {
                 break;
             }
-            let r = self.file_map.get(file).context("Error file not found")?;
+            let file_map = self.file_map.lock();
+            let r = file_map.get(file).context("Error file not found")?;
             let slice_max = needed.min(r.len_chars());
             let rope_str_slice = r
                 .get_slice(0..slice_max)
@@ -94,12 +98,13 @@ impl FileStore {
     ) -> anyhow::Result<String> {
         let rope = self
             .file_map
+            .lock()
             .get(position.text_document.uri.as_str())
             .context("Error file not found")?
             .clone();
         let cursor_index = rope.line_to_char(position.position.line as usize)
             + position.position.character as usize;
-        let start = cursor_index.checked_sub(characters / 2).unwrap_or(0);
+        let start = cursor_index.saturating_sub(characters / 2);
         let end = rope
             .len_chars()
             .min(cursor_index + (characters - (cursor_index - start)));
@@ -137,15 +142,15 @@ impl FileStore {
                 if is_chat_enabled || rope.len_chars() != cursor_index =>
             {
                 let max_length = tokens_to_estimated_characters(max_context_length);
-                let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
+                let start = cursor_index.saturating_sub(max_length / 2);
                 let end = rope
                     .len_chars()
                     .min(cursor_index + (max_length - (cursor_index - start)));
 
                 if is_chat_enabled {
-                    rope.insert(cursor_index, "{CURSOR}");
+                    rope.insert(cursor_index, "<CURSOR>");
                     let rope_slice = rope
-                        .get_slice(start..end + "{CURSOR}".chars().count())
+                        .get_slice(start..end + "<CURSOR>".chars().count())
                         .context("Error getting rope slice")?;
                     rope_slice.to_string()
                 } else {
@@ -166,9 +171,8 @@ impl FileStore {
                 }
             }
             _ => {
-                let start = cursor_index
-                    .checked_sub(tokens_to_estimated_characters(max_context_length))
-                    .unwrap_or(0);
+                let start =
+                    cursor_index.saturating_sub(tokens_to_estimated_characters(max_context_length));
                 let rope_slice = rope
                     .get_slice(start..cursor_index)
                     .context("Error getting rope slice")?;
@@ -178,11 +182,16 @@ impl FileStore {
     }
 }
 
+#[async_trait::async_trait]
 impl MemoryBackend for FileStore {
     #[instrument(skip(self))]
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
+    async fn get_filter_text(
+        &self,
+        position: &TextDocumentPositionParams,
+    ) -> anyhow::Result<String> {
         let rope = self
             .file_map
+            .lock()
             .get(position.text_document.uri.as_str())
             .context("Error file not found")?
             .clone();
@@ -193,8 +202,8 @@ impl MemoryBackend for FileStore {
     }
 
     #[instrument(skip(self))]
-    fn build_prompt(
-        &mut self,
+    async fn build_prompt(
+        &self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
     ) -> anyhow::Result<Prompt> {
@@ -207,25 +216,25 @@ impl MemoryBackend for FileStore {
     }
 
     #[instrument(skip(self))]
-    fn opened_text_document(
-        &mut self,
+    async fn opened_text_document(
+        &self,
         params: lsp_types::DidOpenTextDocumentParams,
     ) -> anyhow::Result<()> {
         let rope = Rope::from_str(&params.text_document.text);
         let uri = params.text_document.uri.to_string();
-        self.file_map.insert(uri.clone(), rope);
-        self.accessed_files.shift_insert(0, uri);
+        self.file_map.lock().insert(uri.clone(), rope);
+        self.accessed_files.lock().shift_insert(0, uri);
         Ok(())
     }
 
     #[instrument(skip(self))]
-    fn changed_text_document(
-        &mut self,
+    async fn changed_text_document(
+        &self,
         params: lsp_types::DidChangeTextDocumentParams,
     ) -> anyhow::Result<()> {
         let uri = params.text_document.uri.to_string();
-        let rope = self
-            .file_map
+        let mut file_map = self.file_map.lock();
+        let rope = file_map
            .get_mut(&uri)
            .context("Error trying to get file that does not exist")?;
         for change in params.content_changes {
@@ -241,15 +250,16 @@ impl MemoryBackend for FileStore {
                 *rope = Rope::from_str(&change.text);
             }
         }
-        self.accessed_files.shift_insert(0, uri);
+        self.accessed_files.lock().shift_insert(0, uri);
         Ok(())
     }
 
     #[instrument(skip(self))]
-    fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
+    async fn renamed_file(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
         for file_rename in params.files {
-            if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {
-                self.file_map.insert(file_rename.new_uri, rope);
+            let mut file_map = self.file_map.lock();
+            if let Some(rope) = file_map.remove(&file_rename.old_uri) {
+                file_map.insert(file_rename.new_uri, rope);
             }
         }
         Ok(())
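Two small patterns recur throughout the `FileStore` changes above: the maps move inside a `Mutex` so the new async trait methods can take `&self` instead of `&mut self`, and `checked_sub(..).unwrap_or(0)` becomes the equivalent but terser `saturating_sub`. A minimal sketch, using `std::sync::Mutex` where the crate uses `parking_lot`:

use std::{collections::HashMap, sync::Mutex};

struct FileStore {
    // Interior mutability: the map can be mutated through &self,
    // which is what shared, Sync trait objects require
    file_map: Mutex<HashMap<String, String>>,
}

impl FileStore {
    fn opened(&self, uri: String, text: String) {
        self.file_map.lock().unwrap().insert(uri, text);
    }
}

fn main() {
    let store = FileStore { file_map: Mutex::new(HashMap::new()) };
    store.opened("file:///a.rs".into(), "fn main() {}".into());

    // saturating_sub clamps at zero instead of underflowing,
    // matching the old checked_sub(..).unwrap_or(0)
    let needed: usize = 5usize.saturating_sub(10);
    assert_eq!(needed, 0);
    assert_eq!(5usize.checked_sub(10).unwrap_or(0), needed);
    println!("files tracked: {}", store.file_map.lock().unwrap().len());
}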
src/memory_backends/mod.rs

@@ -26,22 +26,29 @@ pub enum PromptForType {
     Generate,
 }
 
+#[async_trait::async_trait]
 pub trait MemoryBackend {
-    fn init(&self) -> anyhow::Result<()> {
+    async fn init(&self) -> anyhow::Result<()> {
         Ok(())
     }
-    fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
-    fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
-    fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
-    fn build_prompt(
-        &mut self,
+    async fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
+    async fn changed_text_document(
+        &self,
+        params: DidChangeTextDocumentParams,
+    ) -> anyhow::Result<()>;
+    async fn renamed_file(&self, params: RenameFilesParams) -> anyhow::Result<()>;
+    async fn build_prompt(
+        &self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
     ) -> anyhow::Result<Prompt>;
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
+    async fn get_filter_text(
+        &self,
+        position: &TextDocumentPositionParams,
+    ) -> anyhow::Result<String>;
 }
 
-impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
+impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send + Sync> {
     type Error = anyhow::Error;
 
     fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
@@ -55,3 +62,22 @@ impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
         }
     }
 }
+
+// This makes testing much easier. Every transformer backend takes in a prompt. When verifying they work, it's
+// easier to just pass in a default prompt.
+#[cfg(test)]
+impl Prompt {
+    pub fn default_with_cursor() -> Self {
+        Self {
+            context: r#"def test_context():\n pass"#.to_string(),
+            code: r#"def test_code():\n <CURSOR>"#.to_string(),
+        }
+    }
+
+    pub fn default_without_cursor() -> Self {
+        Self {
+            context: r#"def test_context():\n pass"#.to_string(),
+            code: r#"def test_code():\n "#.to_string(),
+        }
+    }
+}
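The `#[async_trait::async_trait]` attribute is what makes `async fn` in this trait object-safe: it desugars each method to return a boxed future, and the added `Send + Sync` bounds let one boxed backend be shared across threads and tokio tasks. A compressed sketch of the shape (toy trait, not the real `MemoryBackend` surface):

use std::sync::Arc;

#[async_trait::async_trait]
trait MemoryBackend {
    async fn get_filter_text(&self, uri: &str) -> anyhow::Result<String>;
}

struct InMemory;

#[async_trait::async_trait]
impl MemoryBackend for InMemory {
    async fn get_filter_text(&self, uri: &str) -> anyhow::Result<String> {
        Ok(format!("filter text for {uri}"))
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Send + Sync lets the boxed backend sit behind an Arc and
    // be moved into concurrently running tasks
    let backend: Arc<Box<dyn MemoryBackend + Send + Sync>> = Arc::new(Box::new(InMemory));
    let task_backend = backend.clone();
    let text = tokio::spawn(async move {
        task_backend.get_filter_text("file:///a.rs").await
    })
    .await??;
    println!("{text}");
    Ok(())
}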
src/memory_backends/postgresml/mod.rs

@@ -40,7 +40,7 @@ impl PostgresML {
         };
         // TODO: Think on the naming of the collection
         // Maybe filter on metadata or I'm not sure
-        let collection = Collection::new("test-lsp-ai-2", Some(database_url))?;
+        let collection = Collection::new("test-lsp-ai-3", Some(database_url))?;
         // TODO: Review the pipeline
         let pipeline = Pipeline::new(
             "v1",
@@ -50,7 +50,7 @@ impl PostgresML {
                 "splitter": {
                     "model": "recursive_character",
                     "parameters": {
-                        "chunk_size": 512,
+                        "chunk_size": 1500,
                         "chunk_overlap": 40
                     }
                 },
@@ -90,7 +90,7 @@ impl PostgresML {
             .into_iter()
             .map(|path| {
                 let text = std::fs::read_to_string(&path)
-                    .expect(format!("Error reading path: {}", path).as_str());
+                    .unwrap_or_else(|_| panic!("Error reading path: {}", path));
                 json!({
                     "id": path,
                     "text": text
@@ -121,24 +121,28 @@ impl PostgresML {
     }
 }
 
+#[async_trait::async_trait]
 impl MemoryBackend for PostgresML {
     #[instrument(skip(self))]
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
-        self.file_store.get_filter_text(position)
+    async fn get_filter_text(
+        &self,
+        position: &TextDocumentPositionParams,
+    ) -> anyhow::Result<String> {
+        self.file_store.get_filter_text(position).await
     }
 
     #[instrument(skip(self))]
-    fn build_prompt(
-        &mut self,
+    async fn build_prompt(
+        &self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
     ) -> anyhow::Result<Prompt> {
-        // This is blocking, but that is ok as we only query for it from the worker when we are actually doing a transform
         let query = self
            .file_store
            .get_characters_around_position(position, 512)?;
-        let res = self.runtime.block_on(
-            self.collection.vector_search(
+        let res = self
+            .collection
+            .vector_search_local(
                 json!({
                     "query": {
                         "fields": {
@@ -150,9 +154,9 @@ impl MemoryBackend for PostgresML {
                     "limit": 5
                 })
                 .into(),
-                &mut self.pipeline,
-            ),
-        )?;
+                &self.pipeline,
+            )
+            .await?;
         let context = res
             .into_iter()
             .map(|c| {
@@ -176,8 +180,8 @@ impl MemoryBackend for PostgresML {
     }
 
     #[instrument(skip(self))]
-    fn opened_text_document(
-        &mut self,
+    async fn opened_text_document(
+        &self,
         params: lsp_types::DidOpenTextDocumentParams,
     ) -> anyhow::Result<()> {
         let text = params.text_document.text.clone();
@@ -185,68 +189,63 @@ impl MemoryBackend for PostgresML {
         let task_added_pipeline = self.added_pipeline;
         let mut task_collection = self.collection.clone();
         let mut task_pipeline = self.pipeline.clone();
-        self.runtime.spawn(async move {
-            if !task_added_pipeline {
-                task_collection
-                    .add_pipeline(&mut task_pipeline)
-                    .await
-                    .expect("PGML - Error adding pipeline to collection");
-            }
+        if !task_added_pipeline {
+            task_collection
+                .add_pipeline(&mut task_pipeline)
+                .await
+                .expect("PGML - Error adding pipeline to collection");
+        }
         task_collection
             .upsert_documents(
                 vec![json!({
                     "id": path,
                     "text": text
                 })
                 .into()],
                 None,
             )
             .await
             .expect("PGML - Error upserting documents");
+        self.file_store.opened_text_document(params).await
+    }
+
+    #[instrument(skip(self))]
+    async fn changed_text_document(
+        &self,
+        params: lsp_types::DidChangeTextDocumentParams,
+    ) -> anyhow::Result<()> {
+        let path = params.text_document.uri.path().to_owned();
+        self.debounce_tx.send(path)?;
+        self.file_store.changed_text_document(params).await
+    }
+
+    #[instrument(skip(self))]
+    async fn renamed_file(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
+        let mut task_collection = self.collection.clone();
+        let task_params = params.clone();
+        for file in task_params.files {
+            task_collection
+                .delete_documents(
+                    json!({
+                        "id": file.old_uri
+                    })
+                    .into(),
+                )
+                .await
+                .expect("PGML - Error deleting file");
+            let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
+            task_collection
+                .upsert_documents(
+                    vec![json!({
-                        "id": path,
+                        "id": file.new_uri,
                         "text": text
                     })
                     .into()],
                     None,
                 )
                 .await
-                .expect("PGML - Error upserting documents");
-        });
-        self.file_store.opened_text_document(params)
-    }
-
-    #[instrument(skip(self))]
-    fn changed_text_document(
-        &mut self,
-        params: lsp_types::DidChangeTextDocumentParams,
-    ) -> anyhow::Result<()> {
-        let path = params.text_document.uri.path().to_owned();
-        self.debounce_tx.send(path)?;
-        self.file_store.changed_text_document(params)
-    }
-
-    #[instrument(skip(self))]
-    fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
-        let mut task_collection = self.collection.clone();
-        let task_params = params.clone();
-        self.runtime.spawn(async move {
-            for file in task_params.files {
-                task_collection
-                    .delete_documents(
-                        json!({
-                            "id": file.old_uri
-                        })
-                        .into(),
-                    )
-                    .await
-                    .expect("PGML - Error deleting file");
-                let text =
-                    std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
-                task_collection
-                    .upsert_documents(
-                        vec![json!({
-                            "id": file.new_uri,
-                            "text": text
-                        })
-                        .into()],
-                        None,
-                    )
-                    .await
-                    .expect("PGML - Error adding pipeline to collection");
-            }
-        });
-        self.file_store.renamed_file(params)
+                .expect("PGML - Error adding pipeline to collection");
+        }
+        self.file_store.renamed_file(params).await
     }
 }
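In the PostgresML backend above, `changed_text_document` only pushes the changed path onto `debounce_tx`; a separate worker (not shown in this diff) presumably coalesces bursts of edits before re-syncing. A generic sketch of that debounce idea, under the assumption that the drain side batches on a quiet window:

use std::{collections::HashSet, sync::mpsc, thread, time::Duration};

fn main() {
    let (debounce_tx, debounce_rx) = mpsc::channel::<String>();

    let worker = thread::spawn(move || {
        let mut dirty: HashSet<String> = HashSet::new();
        loop {
            match debounce_rx.recv_timeout(Duration::from_millis(50)) {
                // Collect change notifications while edits keep arriving...
                Ok(path) => {
                    dirty.insert(path);
                }
                // ...and flush the deduplicated set once things settle
                Err(mpsc::RecvTimeoutError::Timeout) => {
                    for path in dirty.drain() {
                        println!("re-syncing {path}");
                    }
                }
                Err(mpsc::RecvTimeoutError::Disconnected) => {
                    for path in dirty.drain() {
                        println!("re-syncing {path}");
                    }
                    break;
                }
            }
        }
    });

    // A burst of 100 edits to the same file collapses into one re-sync
    for _ in 0..100 {
        debounce_tx.send("file:///a.rs".to_string()).unwrap();
    }
    drop(debounce_tx);
    worker.join().unwrap();
}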
src/memory_worker.rs (new file): 100 lines
@@ -0,0 +1,100 @@
+use std::sync::Arc;
+
+use lsp_types::{
+    DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
+    TextDocumentPositionParams,
+};
+
+use crate::memory_backends::{MemoryBackend, Prompt, PromptForType};
+
+#[derive(Debug)]
+pub struct PromptRequest {
+    position: TextDocumentPositionParams,
+    prompt_for_type: PromptForType,
+    tx: tokio::sync::oneshot::Sender<Prompt>,
+}
+
+impl PromptRequest {
+    pub fn new(
+        position: TextDocumentPositionParams,
+        prompt_for_type: PromptForType,
+        tx: tokio::sync::oneshot::Sender<Prompt>,
+    ) -> Self {
+        Self {
+            position,
+            prompt_for_type,
+            tx,
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct FilterRequest {
+    position: TextDocumentPositionParams,
+    tx: tokio::sync::oneshot::Sender<String>,
+}
+
+impl FilterRequest {
+    pub fn new(
+        position: TextDocumentPositionParams,
+        tx: tokio::sync::oneshot::Sender<String>,
+    ) -> Self {
+        Self { position, tx }
+    }
+}
+
+pub enum WorkerRequest {
+    FilterText(FilterRequest),
+    Prompt(PromptRequest),
+    DidOpenTextDocument(DidOpenTextDocumentParams),
+    DidChangeTextDocument(DidChangeTextDocumentParams),
+    DidRenameFiles(RenameFilesParams),
+}
+
+pub fn run(
+    memory_backend: Box<dyn MemoryBackend + Send + Sync>,
+    rx: std::sync::mpsc::Receiver<WorkerRequest>,
+) -> anyhow::Result<()> {
+    let memory_backend = Arc::new(memory_backend);
+    let runtime = tokio::runtime::Builder::new_multi_thread()
+        .worker_threads(4)
+        .enable_all()
+        .build()?;
+    loop {
+        let request = rx.recv()?;
+        let thread_memory_backend = memory_backend.clone();
+        runtime.spawn(async move {
+            match request {
+                WorkerRequest::FilterText(params) => {
+                    let filter_text = thread_memory_backend
+                        .get_filter_text(&params.position)
+                        .await
+                        .unwrap();
+                    params.tx.send(filter_text).unwrap();
+                }
+                WorkerRequest::Prompt(params) => {
+                    let prompt = thread_memory_backend
+                        .build_prompt(&params.position, params.prompt_for_type)
+                        .await
+                        .unwrap();
+                    params.tx.send(prompt).unwrap();
+                }
+                WorkerRequest::DidOpenTextDocument(params) => {
+                    thread_memory_backend
+                        .opened_text_document(params)
+                        .await
+                        .unwrap();
+                }
+                WorkerRequest::DidChangeTextDocument(params) => {
+                    thread_memory_backend
+                        .changed_text_document(params)
+                        .await
+                        .unwrap();
+                }
+                WorkerRequest::DidRenameFiles(params) => {
+                    thread_memory_backend.renamed_file(params).await.unwrap()
+                }
+            }
+        });
+    }
+}
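`memory_worker::run` above takes requests from a synchronous mpsc channel and answers each one through the tokio oneshot sender carried inside the request, so callers get exactly one reply without sharing locks. A trimmed sketch of that request/response handshake (hypothetical `FilterRequest`, simpler than the worker's real types):

use tokio::sync::oneshot;

struct FilterRequest {
    uri: String,
    // The worker sends exactly one reply back through this channel
    tx: oneshot::Sender<String>,
}

#[tokio::main]
async fn main() {
    let (request_tx, request_rx) = std::sync::mpsc::channel::<FilterRequest>();

    // Stand-in for memory_worker::run: drain requests and answer each oneshot
    let worker = std::thread::spawn(move || {
        while let Ok(req) = request_rx.recv() {
            let _ = req.tx.send(format!("filter text for {}", req.uri));
        }
    });

    let (tx, rx) = oneshot::channel();
    request_tx
        .send(FilterRequest { uri: "file:///a.rs".into(), tx })
        .unwrap();
    // The caller awaits the single reply without holding any locks
    println!("{}", rx.await.unwrap());

    drop(request_tx);
    worker.join().unwrap();
}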
src/transformer_backends/anthropic.rs (new file): 207 lines
@@ -0,0 +1,207 @@
+use anyhow::Context;
+use serde::Deserialize;
+use serde_json::{json, Value};
+use tracing::instrument;
+
+use crate::{
+    configuration::{self, ChatMessage},
+    memory_backends::Prompt,
+    transformer_worker::{
+        DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
+    },
+    utils::format_chat_messages,
+};
+
+use super::TransformerBackend;
+
+pub struct Anthropic {
+    configuration: configuration::Anthropic,
+}
+
+#[derive(Deserialize)]
+struct AnthropicChatMessage {
+    text: String,
+}
+
+#[derive(Deserialize)]
+struct AnthropicChatResponse {
+    content: Option<Vec<AnthropicChatMessage>>,
+    error: Option<Value>,
+}
+
+impl Anthropic {
+    #[instrument]
+    pub fn new(configuration: configuration::Anthropic) -> Self {
+        Self { configuration }
+    }
+
+    async fn get_chat(
+        &self,
+        system_prompt: String,
+        messages: Vec<ChatMessage>,
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
+        eprintln!(
+            "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
+            messages
+        );
+        let client = reqwest::Client::new();
+        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
+            std::env::var(env_var_name)?
+        } else if let Some(token) = &self.configuration.auth_token {
+            token.to_string()
+        } else {
+            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `anthropic` to use the Anthropic API");
+        };
+        let res: AnthropicChatResponse = client
+            .post(
+                self.configuration
+                    .chat_endpoint
+                    .as_ref()
+                    .context("must specify `chat_endpoint` to use chat")?,
+            )
+            .header("x-api-key", token)
+            .header("anthropic-version", "2023-06-01")
+            .header("Content-Type", "application/json")
+            .header("Accept", "application/json")
+            .json(&json!({
+                "model": self.configuration.model,
+                "system": system_prompt,
+                "max_tokens": max_tokens,
+                "top_p": self.configuration.top_p,
+                "temperature": self.configuration.temperature,
+                "messages": messages
+            }))
+            .send()
+            .await?
+            .json()
+            .await?;
+        if let Some(error) = res.error {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(mut content) = res.content {
+            Ok(std::mem::take(&mut content[0].text))
+        } else {
+            anyhow::bail!("Unknown error while making request to Anthropic")
+        }
+    }
+
+    async fn do_get_chat(
+        &self,
+        prompt: &Prompt,
+        messages: &[ChatMessage],
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
+        let mut messages = format_chat_messages(messages, prompt);
+        if messages[0].role != "system" {
+            anyhow::bail!(
+                "When using Anthropic, the first message in chat must have role = `system`"
+            )
+        }
+        let system_prompt = messages.remove(0).content;
+        self.get_chat(system_prompt, messages, max_tokens).await
+    }
+}
+
+#[async_trait::async_trait]
+impl TransformerBackend for Anthropic {
+    #[instrument(skip(self))]
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+        eprintln!("--------------{:?}---------------", prompt);
+        let max_tokens = self.configuration.max_tokens.completion;
+        let insert_text = match &self.configuration.chat.completion {
+            Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
+            None => {
+                anyhow::bail!("Please provide `anthropic->chat->completion` messages")
+            }
+        };
+        Ok(DoCompletionResponse { insert_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+        eprintln!("--------------{:?}---------------", prompt);
+        let max_tokens = self.configuration.max_tokens.generation;
+        let generated_text = match &self.configuration.chat.generation {
+            Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
+            None => {
+                anyhow::bail!("Please provide `anthropic->chat->generation` messages")
+            }
+        };
+        Ok(DoGenerateResponse { generated_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate_stream(
+        &self,
+        request: &GenerateStreamRequest,
+    ) -> anyhow::Result<DoGenerateStreamResponse> {
+        unimplemented!()
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[tokio::test]
+    async fn anthropic_chat_do_completion() -> anyhow::Result<()> {
+        let configuration: configuration::Anthropic = serde_json::from_value(json!({
+            "chat_endpoint": "https://api.anthropic.com/v1/messages",
+            "model": "claude-3-haiku-20240307",
+            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
+            "chat": {
+                "completion": [
+                    {
+                        "role": "system",
+                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
+                    },
+                    {
+                        "role": "user",
+                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                    }
+                ],
+            },
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let anthropic = Anthropic::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = anthropic.do_completion(&prompt).await?;
+        assert!(!response.insert_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn anthropic_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: configuration::Anthropic = serde_json::from_value(json!({
+            "chat_endpoint": "https://api.anthropic.com/v1/messages",
+            "model": "claude-3-haiku-20240307",
+            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
+            "chat": {
+                "generation": [
+                    {
+                        "role": "system",
+                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
+                    },
+                    {
+                        "role": "user",
+                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                    }
+                ]
+            },
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let anthropic = Anthropic::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = anthropic.do_generate(&prompt).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+}
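One detail worth calling out in `do_get_chat` above: Anthropic's messages endpoint takes the system prompt as a top-level `system` field rather than as a `role: "system"` chat message, so the code peels the first message off and bails if it is not a system message. A sketch of that split in isolation:

#[derive(Clone, Debug)]
struct ChatMessage {
    role: String,
    content: String,
}

// Mirrors do_get_chat: the leading system message becomes the top-level
// `system` field; the rest stay in the `messages` array
fn split_system(mut messages: Vec<ChatMessage>) -> anyhow::Result<(String, Vec<ChatMessage>)> {
    if messages.first().map(|m| m.role.as_str()) != Some("system") {
        anyhow::bail!("the first message in chat must have role = `system`");
    }
    let system_prompt = messages.remove(0).content;
    Ok((system_prompt, messages))
}

fn main() -> anyhow::Result<()> {
    let messages = vec![
        ChatMessage { role: "system".into(), content: "You are a coding assistant.".into() },
        ChatMessage { role: "user".into(), content: "Complete this code.".into() },
    ];
    let (system, rest) = split_system(messages)?;
    println!("system = {system:?}, remaining messages = {}", rest.len());
    Ok(())
}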
src/transformer_backends/llama_cpp/mod.rs

@@ -2,20 +2,21 @@ use anyhow::Context;
 use hf_hub::api::sync::ApiBuilder;
 use tracing::{debug, instrument};
 
-use super::TransformerBackend;
 use crate::{
     configuration::{self},
     memory_backends::Prompt,
     template::apply_chat_template,
-    utils::format_chat_messages,
-    worker::{
+    transformer_worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
+    utils::format_chat_messages,
 };
 
 mod model;
 use model::Model;
 
+use super::TransformerBackend;
+
 pub struct LlamaCPP {
     model: Model,
     configuration: configuration::ModelGGUF,
@@ -62,9 +63,10 @@ impl LlamaCPP {
     }
 }
 
+#[async_trait::async_trait]
 impl TransformerBackend for LlamaCPP {
     #[instrument(skip(self))]
-    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
         // let prompt = self.get_prompt_string(prompt)?;
         let prompt = &prompt.code;
         debug!("Prompt string for LLM: {}", prompt);
@@ -75,7 +77,7 @@ impl TransformerBackend for LlamaCPP {
     }
 
     #[instrument(skip(self))]
-    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
         // let prompt = self.get_prompt_string(prompt)?;
         // debug!("Prompt string for LLM: {}", prompt);
         let prompt = &prompt.code;
@@ -86,7 +88,7 @@ impl TransformerBackend for LlamaCPP {
     }
 
     #[instrument(skip(self))]
-    fn do_generate_stream(
+    async fn do_generate_stream(
         &self,
         _request: &GenerateStreamRequest,
     ) -> anyhow::Result<DoGenerateStreamResponse> {
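The llama.cpp methods above became `async` only to satisfy the new trait; the inference call itself is still synchronous and compute-bound. This commit simply calls it inline, but a common alternative, sketched below under the assumption of some blocking `run_inference` helper (hypothetical, not this crate's API), is to push such work onto tokio's blocking pool:

// A stand-in for the synchronous llama.cpp inference call
fn run_inference(prompt: &str) -> String {
    format!("completion for: {prompt}")
}

async fn do_completion(prompt: String) -> anyhow::Result<String> {
    // spawn_blocking moves the CPU-bound call off the async worker threads
    let text = tokio::task::spawn_blocking(move || run_inference(&prompt)).await?;
    Ok(text)
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    println!("{}", do_completion("def fib".to_string()).await?);
    Ok(())
}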
src/transformer_backends/mod.rs

@@ -1,25 +1,26 @@
 use crate::{
     configuration::{Configuration, ValidTransformerBackend},
     memory_backends::Prompt,
-    worker::{
+    transformer_worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
 };
 
+mod anthropic;
 mod llama_cpp;
 mod openai;
 
+#[async_trait::async_trait]
 pub trait TransformerBackend {
     // Should all take an enum of chat messages or just a string for completion
-    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
-    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
-    fn do_generate_stream(
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
+    async fn do_generate_stream(
         &self,
         request: &GenerateStreamRequest,
     ) -> anyhow::Result<DoGenerateStreamResponse>;
 }
 
-impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send> {
+impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send + Sync> {
     type Error = anyhow::Error;
 
     fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
@@ -30,6 +31,9 @@ impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send> {
             ValidTransformerBackend::OpenAI(openai_config) => {
                 Ok(Box::new(openai::OpenAI::new(openai_config)))
             }
+            ValidTransformerBackend::Anthropic(anthropic_config) => {
+                Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
+            }
         }
     }
 }
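The `TryFrom<Configuration>` impl above is the single dispatch point from parsed configuration to a boxed backend; adding Anthropic meant adding one match arm. The pattern in isolation (toy types, `From` instead of `TryFrom` since nothing here can fail):

enum ValidBackend {
    OpenAI(String),
    Anthropic(String),
}

trait Backend {
    fn name(&self) -> String;
}

struct OpenAI(String);
struct Anthropic(String);

impl Backend for OpenAI {
    fn name(&self) -> String {
        format!("openai:{}", self.0)
    }
}

impl Backend for Anthropic {
    fn name(&self) -> String {
        format!("anthropic:{}", self.0)
    }
}

// Mirrors the TryFrom<Configuration> impl: each config variant maps to one boxed backend
impl From<ValidBackend> for Box<dyn Backend + Send + Sync> {
    fn from(config: ValidBackend) -> Self {
        match config {
            ValidBackend::OpenAI(model) => Box::new(OpenAI(model)),
            ValidBackend::Anthropic(model) => Box::new(Anthropic(model)),
        }
    }
}

fn main() {
    let backend: Box<dyn Backend + Send + Sync> =
        ValidBackend::Anthropic("claude-3-haiku-20240307".to_string()).into();
    println!("{}", backend.name());
}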
@@ -1,16 +1,22 @@
|
||||
// Something more about what this file is
|
||||
// NOTE: When decoding responses from OpenAI compatbile services, we don't care about every field
|
||||
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::instrument;
|
||||
|
||||
use super::TransformerBackend;
|
||||
use crate::{
|
||||
configuration,
|
||||
configuration::{self, ChatMessage},
|
||||
memory_backends::Prompt,
|
||||
worker::{
|
||||
transformer_worker::{
|
||||
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
|
||||
},
|
||||
utils::{format_chat_messages, format_context_code},
|
||||
};
|
||||
|
||||
use super::TransformerBackend;
|
||||
|
||||
pub struct OpenAI {
|
||||
configuration: configuration::OpenAI,
|
||||
}
|
||||
@@ -22,7 +28,19 @@ struct OpenAICompletionsChoice {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAICompletionsResponse {
|
||||
choices: Vec<OpenAICompletionsChoice>,
|
||||
choices: Option<Vec<OpenAICompletionsChoice>>,
|
||||
error: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIChatChoices {
|
||||
message: ChatMessage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIChatResponse {
|
||||
choices: Option<Vec<OpenAIChatChoices>>,
|
||||
error: Option<Value>,
|
||||
}
|
||||
|
||||
impl OpenAI {
|
||||
@@ -31,18 +49,27 @@ impl OpenAI {
|
||||
Self { configuration }
|
||||
}
|
||||
|
||||
fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
|
||||
eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
|
||||
std::env::var(env_var_name)?
|
||||
fn get_token(&self) -> anyhow::Result<String> {
|
||||
if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
|
||||
Ok(std::env::var(env_var_name)?)
|
||||
} else if let Some(token) = &self.configuration.auth_token {
|
||||
token.to_string()
|
||||
Ok(token.to_string())
|
||||
} else {
|
||||
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
|
||||
};
|
||||
anyhow::bail!("set `auth_token_env_var_name` or `auth_token` in `tranformer->openai` to use an OpenAI compatible API")
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
|
||||
eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
|
||||
let client = reqwest::Client::new();
|
||||
let token = self.get_token()?;
|
||||
let res: OpenAICompletionsResponse = client
|
||||
.post(&self.configuration.completions_endpoint)
|
||||
.post(
|
||||
self.configuration
|
||||
.completions_endpoint
|
||||
.as_ref()
|
||||
.context("specify `transformer->openai->completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `transformer->openai->chat_endpoint` and `transformer->openai->chat` messages.")?,
|
||||
)
|
||||
.bearer_auth(token)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json")
|
||||
@@ -51,43 +78,231 @@ impl OpenAI {
|
||||
"max_tokens": max_tokens,
|
||||
"n": 1,
|
||||
"top_p": self.configuration.top_p,
|
||||
"top_k": self.configuration.top_k,
|
||||
"presence_penalty": self.configuration.presence_penalty,
|
||||
"frequency_penalty": self.configuration.frequency_penalty,
|
||||
"temperature": self.configuration.temperature,
|
||||
"echo": false,
|
||||
"prompt": prompt
|
||||
}))
|
||||
.send()?
|
||||
.json()?;
|
||||
eprintln!("**********RECEIVED REQUEST********");
|
||||
Ok(res.choices[0].text.clone())
|
||||
.send().await?
|
||||
.json().await?;
|
||||
if let Some(error) = res.error {
|
||||
anyhow::bail!("{:?}", error.to_string())
|
||||
} else if let Some(mut choices) = res.choices {
|
||||
Ok(std::mem::take(&mut choices[0].text))
|
||||
} else {
|
||||
anyhow::bail!("Uknown error while making request to OpenAI")
|
||||
}
|
||||
}
|
||||
|
||||
    async fn get_chat(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: usize,
    ) -> anyhow::Result<String> {
        eprintln!(
            "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
            messages
        );
        let client = reqwest::Client::new();
        let token = self.get_token()?;
        let res: OpenAIChatResponse = client
            .post(
                self.configuration
                    .chat_endpoint
                    .as_ref()
                    .context("specify `transformer->openai->chat_endpoint` to use chat")?,
            )
            .bearer_auth(token)
            .header("Content-Type", "application/json")
            .header("Accept", "application/json")
            .json(&json!({
                "model": self.configuration.model,
                "max_tokens": max_tokens,
                "n": 1,
                "top_p": self.configuration.top_p,
                "presence_penalty": self.configuration.presence_penalty,
                "frequency_penalty": self.configuration.frequency_penalty,
                "temperature": self.configuration.temperature,
                "messages": messages
            }))
            .send()
            .await?
            .json()
            .await?;
        if let Some(error) = res.error {
            anyhow::bail!("{:?}", error.to_string())
        } else if let Some(choices) = res.choices {
            Ok(choices[0].message.content.clone())
        } else {
            anyhow::bail!("Unknown error while making request to OpenAI")
        }
    }

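As with completions, the chat path implies a response shape with optional `error` and `choices`, where each choice wraps a message; a sketch of the types this code assumes (names inferred from usage here and in `src/utils.rs` below, not confirmed by this diff):

    #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
    struct ChatMessage {
        role: String,
        content: String,
    }

    #[derive(serde::Deserialize)]
    struct OpenAIChatChoice {
        message: ChatMessage,
    }

    #[derive(serde::Deserialize)]
    struct OpenAIChatResponse {
        choices: Option<Vec<OpenAIChatChoice>>,
        error: Option<serde_json::Value>,
    }
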
    async fn do_chat_completion(
        &self,
        prompt: &Prompt,
        messages: Option<&Vec<ChatMessage>>,
        max_tokens: usize,
    ) -> anyhow::Result<String> {
        match messages {
            Some(completion_messages) => {
                let messages = format_chat_messages(completion_messages, prompt);
                self.get_chat(messages, max_tokens).await
            }
            None => {
                self.get_completion(
                    &format_context_code(&prompt.context, &prompt.code),
                    max_tokens,
                )
                .await
            }
        }
    }
}

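The impl below relies on the `TransformerBackend` trait itself being declared with `#[async_trait::async_trait]`. A sketch of the trait shape implied by this impl and by `run()` in `src/transformer_worker.rs` further down, which takes a `Box<dyn TransformerBackend + Send + Sync>` (the actual definition lives elsewhere in the crate):

    #[async_trait::async_trait]
    pub trait TransformerBackend {
        async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
        async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
        async fn do_generate_stream(
            &self,
            request: &GenerateStreamRequest,
        ) -> anyhow::Result<DoGenerateStreamResponse>;
    }
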
#[async_trait::async_trait]
impl TransformerBackend for OpenAI {
    #[instrument(skip(self))]
    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
        eprintln!("--------------{:?}---------------", prompt);
        let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
        let insert_text = self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
        let max_tokens = self.configuration.max_tokens.completion;
        let messages = self
            .configuration
            .chat
            .as_ref()
            .and_then(|c| c.completion.as_ref());
        let insert_text = self
            .do_chat_completion(prompt, messages, max_tokens)
            .await?;
        Ok(DoCompletionResponse { insert_text })
    }

    #[instrument(skip(self))]
    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
        eprintln!("--------------{:?}---------------", prompt);
        let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
        let generated_text =
            self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
        let max_tokens = self.configuration.max_tokens.generation;
        let messages = self
            .configuration
            .chat
            .as_ref()
            .and_then(|c| c.generation.as_ref());
        let generated_text = self
            .do_chat_completion(prompt, messages, max_tokens)
            .await?;
        Ok(DoGenerateResponse { generated_text })
    }

    #[instrument(skip(self))]
    fn do_generate_stream(
    async fn do_generate_stream(
        &self,
        request: &GenerateStreamRequest,
    ) -> anyhow::Result<DoGenerateStreamResponse> {
        unimplemented!()
    }
}

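The `configuration.chat.as_ref().and_then(|c| c.completion.as_ref())` accesses above imply a config section with independent, optional message lists for completion and generation, which is what routes each request through chat or plain completions in `do_chat_completion`. A hypothetical sketch of that shape, with field names inferred from this diff (the real structs live in the `configuration` module):

    pub struct ChatConfiguration {
        pub completion: Option<Vec<ChatMessage>>,
        pub generation: Option<Vec<ChatMessage>>,
    }

    pub struct OpenAIConfiguration {
        pub chat: Option<ChatConfiguration>,
        // ...plus the fields used above: model, the two endpoints,
        // sampling knobs (top_p, top_k, penalties, temperature), max_tokens
    }
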
#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn openai_completion_do_completion() -> anyhow::Result<()> {
        let configuration: configuration::OpenAI = serde_json::from_value(json!({
            "completions_endpoint": "https://api.openai.com/v1/completions",
            "model": "gpt-3.5-turbo-instruct",
            "auth_token_env_var_name": "OPENAI_API_KEY",
            "max_tokens": {
                "completion": 16,
                "generation": 64
            },
            "max_context": 4096
        }))?;
        let openai = OpenAI::new(configuration);
        let prompt = Prompt::default_without_cursor();
        let response = openai.do_completion(&prompt).await?;
        assert!(!response.insert_text.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn openai_chat_do_completion() -> anyhow::Result<()> {
        let configuration: configuration::OpenAI = serde_json::from_value(json!({
            "chat_endpoint": "https://api.openai.com/v1/chat/completions",
            "model": "gpt-3.5-turbo",
            "auth_token_env_var_name": "OPENAI_API_KEY",
            "chat": {
                "completion": [
                    {
                        "role": "system",
                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
                    },
                    {
                        "role": "user",
                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
                    }
                ],
            },
            "max_tokens": {
                "completion": 16,
                "generation": 64
            },
            "max_context": 4096
        }))?;
        let openai = OpenAI::new(configuration);
        let prompt = Prompt::default_with_cursor();
        let response = openai.do_completion(&prompt).await?;
        assert!(!response.insert_text.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn openai_completion_do_generate() -> anyhow::Result<()> {
        let configuration: configuration::OpenAI = serde_json::from_value(json!({
            "completions_endpoint": "https://api.openai.com/v1/completions",
            "model": "gpt-3.5-turbo-instruct",
            "auth_token_env_var_name": "OPENAI_API_KEY",
            "max_tokens": {
                "completion": 16,
                "generation": 64
            },
            "max_context": 4096
        }))?;
        let openai = OpenAI::new(configuration);
        let prompt = Prompt::default_without_cursor();
        let response = openai.do_generate(&prompt).await?;
        assert!(!response.generated_text.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn openai_chat_do_generate() -> anyhow::Result<()> {
        let configuration: configuration::OpenAI = serde_json::from_value(json!({
            "chat_endpoint": "https://api.openai.com/v1/chat/completions",
            "model": "gpt-3.5-turbo",
            "auth_token_env_var_name": "OPENAI_API_KEY",
            "chat": {
                "generation": [
                    {
                        "role": "system",
                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
                    },
                    {
                        "role": "user",
                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
                    }
                ]
            },
            "max_tokens": {
                "completion": 16,
                "generation": 64
            },
            "max_context": 4096
        }))?;
        let openai = OpenAI::new(configuration);
        let prompt = Prompt::default_with_cursor();
        let response = openai.do_generate(&prompt).await?;
        assert!(!response.generated_text.is_empty());
        Ok(())
    }
}

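Note that these are live integration tests: each one builds a configuration pointing at api.openai.com and reads the `OPENAI_API_KEY` environment variable, so they will fail (or spend tokens) outside an environment with valid credentials.
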
src/transformer_worker.rs (new file, 214 lines)
@@ -0,0 +1,214 @@
use lsp_server::{Connection, Message, RequestId, Response};
use lsp_types::{
    CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
    Position, Range, TextEdit,
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use tokio::sync::oneshot;

use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::GenerateStreamParams;
use crate::memory_backends::PromptForType;
use crate::memory_worker::{self, FilterRequest, PromptRequest};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;

#[derive(Clone, Debug)]
pub struct CompletionRequest {
    id: RequestId,
    params: CompletionParams,
}

impl CompletionRequest {
    pub fn new(id: RequestId, params: CompletionParams) -> Self {
        Self { id, params }
    }
}

#[derive(Clone, Debug)]
pub struct GenerateRequest {
    id: RequestId,
    params: GenerateParams,
}

impl GenerateRequest {
    pub fn new(id: RequestId, params: GenerateParams) -> Self {
        Self { id, params }
    }
}

// The generate stream is not yet ready but we don't want to remove it
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct GenerateStreamRequest {
    id: RequestId,
    params: GenerateStreamParams,
}

impl GenerateStreamRequest {
    pub fn new(id: RequestId, params: GenerateStreamParams) -> Self {
        Self { id, params }
    }
}

#[derive(Clone)]
pub enum WorkerRequest {
    Completion(CompletionRequest),
    Generate(GenerateRequest),
    GenerateStream(GenerateStreamRequest),
}

pub struct DoCompletionResponse {
    pub insert_text: String,
}

pub struct DoGenerateResponse {
    pub generated_text: String,
}

pub struct DoGenerateStreamResponse {
    pub generated_text: String,
}

pub fn run(
    transformer_backend: Box<dyn TransformerBackend + Send + Sync>,
    memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
    last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
    connection: Arc<Connection>,
) -> anyhow::Result<()> {
    let transformer_backend = Arc::new(transformer_backend);
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(4)
        .enable_all()
        .build()?;
    loop {
        let option_worker_request: Option<WorkerRequest> = {
            let mut completion_request = last_worker_request.lock();
            std::mem::take(&mut *completion_request)
        };
        if let Some(request) = option_worker_request {
            let thread_connection = connection.clone();
            let thread_transformer_backend = transformer_backend.clone();
            let thread_memory_backend_tx = memory_backend_tx.clone();
            runtime.spawn(async move {
                let response = match request {
                    WorkerRequest::Completion(request) => match do_completion(
                        thread_transformer_backend,
                        thread_memory_backend_tx,
                        &request,
                    )
                    .await
                    {
                        Ok(r) => r,
                        Err(e) => Response {
                            id: request.id,
                            result: None,
                            error: Some(e.to_response_error(-32603)),
                        },
                    },
                    WorkerRequest::Generate(request) => match do_generate(
                        thread_transformer_backend,
                        thread_memory_backend_tx,
                        &request,
                    )
                    .await
                    {
                        Ok(r) => r,
                        Err(e) => Response {
                            id: request.id,
                            result: None,
                            error: Some(e.to_response_error(-32603)),
                        },
                    },
                    WorkerRequest::GenerateStream(_) => {
                        panic!("Streaming is not supported yet")
                    }
                };
                thread_connection
                    .sender
                    .send(Message::Response(response))
                    .expect("Error sending message");
            });
        }
        thread::sleep(std::time::Duration::from_millis(5));
    }
}

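Two design points worth noting in `run`: the `std::mem::take` on `last_worker_request` means the worker only ever services the most recent pending request, dropping stale ones (a natural debounce for rapid-fire completion requests), and the 5 ms sleep keeps the polling loop from spinning. Each serviced request is spawned onto a four-thread tokio runtime, so a slow backend call no longer blocks the loop.
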
async fn do_completion(
    transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
    memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
    request: &CompletionRequest,
) -> anyhow::Result<Response> {
    let (tx, rx) = oneshot::channel();
    memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
        request.params.text_document_position.clone(),
        PromptForType::Completion,
        tx,
    )))?;
    let prompt = rx.await?;

    let (tx, rx) = oneshot::channel();
    memory_backend_tx.send(memory_worker::WorkerRequest::FilterText(
        FilterRequest::new(request.params.text_document_position.clone(), tx),
    ))?;
    let filter_text = rx.await?;

    let response = transformer_backend.do_completion(&prompt).await?;
    let completion_text_edit = TextEdit::new(
        Range::new(
            Position::new(
                request.params.text_document_position.position.line,
                request.params.text_document_position.position.character,
            ),
            Position::new(
                request.params.text_document_position.position.line,
                request.params.text_document_position.position.character,
            ),
        ),
        response.insert_text.clone(),
    );
    let item = CompletionItem {
        label: format!("ai - {}", response.insert_text),
        filter_text: Some(filter_text),
        text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
        kind: Some(CompletionItemKind::TEXT),
        ..Default::default()
    };
    let completion_list = CompletionList {
        is_incomplete: false,
        items: vec![item],
    };
    let result = Some(CompletionResponse::List(completion_list));
    let result = serde_json::to_value(result).unwrap();
    Ok(Response {
        id: request.id.clone(),
        result: Some(result),
        error: None,
    })
}

async fn do_generate(
    transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
    memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
    request: &GenerateRequest,
) -> anyhow::Result<Response> {
    let (tx, rx) = oneshot::channel();
    memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
        request.params.text_document_position.clone(),
        PromptForType::Generate,
        tx,
    )))?;
    let prompt = rx.await?;

    let response = transformer_backend.do_generate(&prompt).await?;
    let result = GenerateResult {
        generated_text: response.generated_text,
    };
    let result = serde_json::to_value(result).unwrap();
    Ok(Response {
        id: request.id.clone(),
        result: Some(result),
        error: None,
    })
}

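`do_completion` and `do_generate` bridge the async transformer task and the synchronous memory worker by sending a tokio `oneshot::Sender` across a `std::sync::mpsc` channel and awaiting the receiver. A minimal self-contained sketch of that pattern (names here are illustrative, not from the codebase):

    use std::thread;

    // Illustrative request: some input plus a oneshot sender for the reply.
    struct Request {
        input: u32,
        tx: tokio::sync::oneshot::Sender<u32>,
    }

    #[tokio::main]
    async fn main() {
        let (req_tx, req_rx) = std::sync::mpsc::channel::<Request>();

        // Synchronous worker thread (the memory worker's role): receive
        // requests, do blocking work, complete the oneshot.
        thread::spawn(move || {
            for req in req_rx {
                let _ = req.tx.send(req.input * 2);
            }
        });

        // Async side (do_completion's role): send a request, await the reply.
        let (tx, rx) = tokio::sync::oneshot::channel();
        req_tx.send(Request { input: 21, tx }).unwrap();
        assert_eq!(rx.await.unwrap(), 42);
    }
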
src/utils.rs (10 lines changed)
@@ -20,15 +20,19 @@ pub fn tokens_to_estimated_characters(tokens: usize) -> usize {
    tokens * 4
}

pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {
pub fn format_chat_messages(messages: &[ChatMessage], prompt: &Prompt) -> Vec<ChatMessage> {
    messages
        .iter()
        .map(|m| ChatMessage {
            role: m.role.to_owned(),
            content: m
                .content
                .replace("{context}", &prompt.context)
                .replace("{code}", &prompt.code),
                .replace("{CONTEXT}", &prompt.context)
                .replace("{CODE}", &prompt.code),
        })
        .collect()
}

pub fn format_context_code(context: &str, code: &str) -> String {
    format!("{context}\n\n{code}")
}

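The change takes `&[ChatMessage]` instead of `&Vec<ChatMessage>` (the more idiomatic borrowed-slice signature) and switches the placeholders to uppercase `{CONTEXT}`/`{CODE}`, matching the prompt templates in the OpenAI tests above. A quick usage sketch (assuming, for illustration only, that `Prompt` can be built from just `context` and `code`, which this diff does not confirm):

    let messages = vec![ChatMessage {
        role: "user".to_string(),
        content: "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}".to_string(),
    }];
    let prompt = Prompt {
        context: "fn add(a: i32, b: i32) -> i32 { a + b }".to_string(),
        code: "let sum = <CURSOR>;".to_string(),
    };
    let formatted = format_chat_messages(&messages, &prompt);
    assert!(formatted[0].content.contains("fn add"));
    assert!(formatted[0].content.contains("<CURSOR>"));
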
src/worker.rs (deleted, 191 lines)
@@ -1,191 +0,0 @@
use lsp_server::{Connection, Message, RequestId, Response};
use lsp_types::{
    CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
    Position, Range, TextEdit,
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use tracing::instrument;

use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::GenerateStreamParams;
use crate::memory_backends::{MemoryBackend, PromptForType};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;

#[derive(Clone, Debug)]
pub struct CompletionRequest {
    id: RequestId,
    params: CompletionParams,
}

impl CompletionRequest {
    pub fn new(id: RequestId, params: CompletionParams) -> Self {
        Self { id, params }
    }
}

#[derive(Clone, Debug)]
pub struct GenerateRequest {
    id: RequestId,
    params: GenerateParams,
}

impl GenerateRequest {
    pub fn new(id: RequestId, params: GenerateParams) -> Self {
        Self { id, params }
    }
}

// The generate stream is not yet ready but we don't want to remove it
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct GenerateStreamRequest {
    id: RequestId,
    params: GenerateStreamParams,
}

impl GenerateStreamRequest {
    pub fn new(id: RequestId, params: GenerateStreamParams) -> Self {
        Self { id, params }
    }
}

#[derive(Clone)]
pub enum WorkerRequest {
    Completion(CompletionRequest),
    Generate(GenerateRequest),
    GenerateStream(GenerateStreamRequest),
}

pub struct DoCompletionResponse {
    pub insert_text: String,
}

pub struct DoGenerateResponse {
    pub generated_text: String,
}

pub struct DoGenerateStreamResponse {
    pub generated_text: String,
}

pub struct Worker {
    transformer_backend: Box<dyn TransformerBackend>,
    memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
    last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
    connection: Arc<Connection>,
}

impl Worker {
    pub fn new(
        transformer_backend: Box<dyn TransformerBackend>,
        memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
        last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
        connection: Arc<Connection>,
    ) -> Self {
        Self {
            transformer_backend,
            memory_backend,
            last_worker_request,
            connection,
        }
    }

    #[instrument(skip(self))]
    fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
        let prompt = self.memory_backend.lock().build_prompt(
            &request.params.text_document_position,
            PromptForType::Completion,
        )?;
        let filter_text = self
            .memory_backend
            .lock()
            .get_filter_text(&request.params.text_document_position)?;
        let response = self.transformer_backend.do_completion(&prompt)?;
        let completion_text_edit = TextEdit::new(
            Range::new(
                Position::new(
                    request.params.text_document_position.position.line,
                    request.params.text_document_position.position.character,
                ),
                Position::new(
                    request.params.text_document_position.position.line,
                    request.params.text_document_position.position.character,
                ),
            ),
            response.insert_text.clone(),
        );
        let item = CompletionItem {
            label: format!("ai - {}", response.insert_text),
            filter_text: Some(filter_text),
            text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
            kind: Some(CompletionItemKind::TEXT),
            ..Default::default()
        };
        let completion_list = CompletionList {
            is_incomplete: false,
            items: vec![item],
        };
        let result = Some(CompletionResponse::List(completion_list));
        let result = serde_json::to_value(&result).unwrap();
        Ok(Response {
            id: request.id.clone(),
            result: Some(result),
            error: None,
        })
    }

    #[instrument(skip(self))]
    fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
        let prompt = self.memory_backend.lock().build_prompt(
            &request.params.text_document_position,
            PromptForType::Generate,
        )?;
        let response = self.transformer_backend.do_generate(&prompt)?;
        let result = GenerateResult {
            generated_text: response.generated_text,
        };
        let result = serde_json::to_value(&result).unwrap();
        Ok(Response {
            id: request.id.clone(),
            result: Some(result),
            error: None,
        })
    }

    pub fn run(self) {
        loop {
            let option_worker_request: Option<WorkerRequest> = {
                let mut completion_request = self.last_worker_request.lock();
                std::mem::take(&mut *completion_request)
            };
            if let Some(request) = option_worker_request {
                let response = match request {
                    WorkerRequest::Completion(request) => match self.do_completion(&request) {
                        Ok(r) => r,
                        Err(e) => Response {
                            id: request.id,
                            result: None,
                            error: Some(e.to_response_error(-32603)),
                        },
                    },
                    WorkerRequest::Generate(request) => match self.do_generate(&request) {
                        Ok(r) => r,
                        Err(e) => Response {
                            id: request.id,
                            result: None,
                            error: Some(e.to_response_error(-32603)),
                        },
                    },
                    WorkerRequest::GenerateStream(_) => panic!("Streaming is not supported yet"),
                };
                self.connection
                    .sender
                    .send(Message::Response(response))
                    .expect("Error sending message");
            }
            thread::sleep(std::time::Duration::from_millis(5));
        }
    }
}

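The deleted `Worker` above is the synchronous predecessor of `src/transformer_worker.rs`: the same polling loop and LSP plumbing, but every backend call ran inline on the worker thread and the memory backend was shared behind a mutex. The overhaul replaces the inline calls with tasks on a tokio runtime and the shared mutex with message passing to a dedicated memory worker.
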
Submodule submodules/postgresml updated: 0842673804...a16ff700c1