Merge pull request #1 from SilasMarvin/silas-async-overhaul

Silas async overhaul
Silas Marvin
2024-03-23 19:02:47 -07:00
committed by GitHub
16 changed files with 1058 additions and 430 deletions

Cargo.lock (generated, 94 changes)

@@ -110,9 +110,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.80"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1"
checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
[[package]]
name = "assert_cmd"
@@ -131,9 +131,9 @@ dependencies = [
[[package]]
name = "async-trait"
version = "0.1.77"
version = "0.1.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85"
dependencies = [
"proc-macro2",
"quote",
@@ -149,16 +149,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "atomic-write-file"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8204db279bf648d64fe845bd8840f78b39c8132ed4d6a4194c3b10d4b4cfb0b"
dependencies = [
"nix",
"rand",
]
[[package]]
name = "autocfg"
version = "1.1.0"
@@ -532,6 +522,16 @@ dependencies = [
"typenum",
]
[[package]]
name = "ctrlc"
version = "3.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345"
dependencies = [
"nix",
"windows-sys 0.52.0",
]
[[package]]
name = "darling"
version = "0.14.4"
@@ -1410,6 +1410,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "llama-cpp-2"
version = "0.1.34"
source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
dependencies = [
"llama-cpp-sys-2",
"thiserror",
@@ -1419,6 +1420,7 @@ dependencies = [
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.34"
source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-apply-chat-template#f810fea8a8a57fd9693de6a77b35b05a1ae77064"
dependencies = [
"bindgen",
"cc",
@@ -1465,6 +1467,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"assert_cmd",
"async-trait",
"directories",
"hf-hub",
"ignore",
@@ -1956,6 +1959,7 @@ dependencies = [
"chrono",
"clap",
"colored",
"ctrlc",
"futures",
"indicatif",
"inquire",
@@ -1963,6 +1967,7 @@ dependencies = [
"itertools 0.10.5",
"lopdf",
"md5",
"once_cell",
"parking_lot",
"regex",
"reqwest",
@@ -2075,9 +2080,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.78"
version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e"
dependencies = [
"unicode-ident",
]
@@ -2224,9 +2229,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
[[package]]
name = "reqwest"
version = "0.11.25"
version = "0.11.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946"
checksum = "78bf93c4af7a8bb7d879d51cebe797356ff10ae8516ace542b5182d9dcac10b2"
dependencies = [
"base64 0.21.7",
"bytes",
@@ -2769,9 +2774,9 @@ dependencies = [
[[package]]
name = "sqlx"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf"
checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa"
dependencies = [
"sqlx-core",
"sqlx-macros",
@@ -2782,9 +2787,9 @@ dependencies = [
[[package]]
name = "sqlx-core"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd"
checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6"
dependencies = [
"ahash",
"atoi",
@@ -2792,7 +2797,6 @@ dependencies = [
"bytes",
"crc",
"crossbeam-queue",
"dotenvy",
"either",
"event-listener",
"futures-channel",
@@ -2827,9 +2831,9 @@ dependencies = [
[[package]]
name = "sqlx-macros"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5"
checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127"
dependencies = [
"proc-macro2",
"quote",
@@ -2840,11 +2844,10 @@ dependencies = [
[[package]]
name = "sqlx-macros-core"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841"
checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8"
dependencies = [
"atomic-write-file",
"dotenvy",
"either",
"heck",
@@ -2867,9 +2870,9 @@ dependencies = [
[[package]]
name = "sqlx-mysql"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4"
checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418"
dependencies = [
"atoi",
"base64 0.21.7",
@@ -2911,9 +2914,9 @@ dependencies = [
[[package]]
name = "sqlx-postgres"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24"
checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e"
dependencies = [
"atoi",
"base64 0.21.7",
@@ -2938,7 +2941,6 @@ dependencies = [
"rand",
"serde",
"serde_json",
"sha1",
"sha2",
"smallvec",
"sqlx-core",
@@ -2952,9 +2954,9 @@ dependencies = [
[[package]]
name = "sqlx-sqlite"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490"
checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
dependencies = [
"atoi",
"flume",
@@ -3051,20 +3053,20 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]]
name = "system-configuration"
version = "0.6.0"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42"
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 2.4.2",
"bitflags 1.3.2",
"core-foundation",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
dependencies = [
"core-foundation-sys",
"libc",
@@ -3090,18 +3092,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"
[[package]]
name = "thiserror"
version = "1.0.57"
version = "1.0.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.57"
version = "1.0.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
dependencies = [
"proc-macro2",
"quote",
@@ -3631,9 +3633,9 @@ dependencies = [
[[package]]
name = "whoami"
version = "1.5.0"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fec781d48b41f8163426ed18e8fc2864c12937df9ce54c88ede7bd47270893e"
checksum = "a44ab49fad634e88f55bf8f9bb3abd2f27d7204172a112c7c9987e01c1c94ea9"
dependencies = [
"redox_syscall",
"wasite",


@@ -20,7 +20,7 @@ parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
# llama-cpp-2 = "0.1.31"
llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
llama-cpp-2 = { git = "https://github.com/SilasMarvin/llama-cpp-rs", branch = "silas-apply-chat-template" }
minijinja = { version = "1.0.12", features = ["loader"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
@@ -30,6 +30,7 @@ ignore = "0.4.22"
pgml = { path = "submodules/postgresml/pgml-sdks/pgml" }
tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
indexmap = "2.2.5"
async-trait = "0.1.78"
[features]
default = []


@@ -19,6 +19,7 @@ pub enum ValidMemoryBackend {
pub enum ValidTransformerBackend {
LlamaCPP(ModelGGUF),
OpenAI(OpenAI),
Anthropic(Anthropic),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -36,6 +37,7 @@ pub struct Chat {
}
#[derive(Clone, Debug, Deserialize)]
#[allow(clippy::upper_case_acronyms)]
pub struct FIM {
pub start: String,
pub middle: String,
@@ -129,10 +131,6 @@ const fn openai_top_p_default() -> f32 {
0.95
}
const fn openai_top_k_default() -> usize {
40
}
const fn openai_presence_penalty() -> f32 {
0.
}
@@ -149,13 +147,15 @@ const fn openai_max_context() -> usize {
DEFAULT_OPENAI_MAX_CONTEXT
}
#[derive(Clone, Debug, Default, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct OpenAI {
// The auth token env var name
pub auth_token_env_var_name: Option<String>,
pub auth_token: Option<String>,
// The completions endpoint
pub completions_endpoint: String,
pub completions_endpoint: Option<String>,
// The chat endpoint
pub chat_endpoint: Option<String>,
// The model name
pub model: String,
// Fill in the middle support
@@ -168,8 +168,6 @@ pub struct OpenAI {
// Other available args
#[serde(default = "openai_top_p_default")]
pub top_p: f32,
#[serde(default = "openai_top_k_default")]
pub top_k: usize,
#[serde(default = "openai_presence_penalty")]
pub presence_penalty: f32,
#[serde(default = "openai_frequency_penalty")]
@@ -180,9 +178,35 @@ pub struct OpenAI {
max_context: usize,
}
#[derive(Clone, Debug, Deserialize)]
pub struct Anthropic {
// The auth token env var name
pub auth_token_env_var_name: Option<String>,
pub auth_token: Option<String>,
// The completions endpoint
pub completions_endpoint: Option<String>,
// The chat endpoint
pub chat_endpoint: Option<String>,
// The model name
pub model: String,
// Fill in the middle support
pub fim: Option<FIM>,
// The maximum number of new tokens to generate
#[serde(default)]
pub max_tokens: MaxTokens,
// Chat args
pub chat: Chat,
// Other available args
#[serde(default = "openai_top_p_default")]
pub top_p: f32,
#[serde(default = "openai_temperature")]
pub temperature: f32,
}
#[derive(Clone, Debug, Deserialize)]
struct ValidTransformerConfiguration {
openai: Option<OpenAI>,
anthropic: Option<Anthropic>,
model_gguf: Option<ModelGGUF>,
}
@@ -190,6 +214,7 @@ impl Default for ValidTransformerConfiguration {
fn default() -> Self {
Self {
model_gguf: Some(ModelGGUF::default()),
anthropic: None,
openai: None,
}
}
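
For reference, a minimal sketch of what the new `anthropic` transformer configuration deserializes from. The struct below is a simplified, hypothetical stand-in for `configuration::Anthropic` (only a few of the fields from the diff, with illustrative default values); it is not the crate's actual type, but it shows how the optional auth fields and serde defaults combine. Assumes the serde (with derive), serde_json, and anyhow crates.

use serde::Deserialize;
use serde_json::json;

// Simplified stand-in for configuration::Anthropic (illustration only).
#[derive(Debug, Deserialize)]
struct AnthropicConfigSketch {
    auth_token_env_var_name: Option<String>,
    auth_token: Option<String>,
    chat_endpoint: Option<String>,
    model: String,
    #[serde(default = "default_top_p")]
    top_p: f32,
    #[serde(default = "default_temperature")]
    temperature: f32,
}

// Illustrative defaults; the crate defines its own default functions.
fn default_top_p() -> f32 {
    0.95
}

fn default_temperature() -> f32 {
    0.1
}

fn main() -> anyhow::Result<()> {
    // The kind of JSON a client might pass under `initializationOptions`.
    let config: AnthropicConfigSketch = serde_json::from_value(json!({
        "chat_endpoint": "https://api.anthropic.com/v1/messages",
        "model": "claude-3-haiku-20240307",
        "auth_token_env_var_name": "ANTHROPIC_API_KEY"
    }))?;
    println!("{config:?}");
    Ok(())
}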


@@ -6,25 +6,31 @@ use lsp_types::{
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use std::{
sync::{mpsc, Arc},
thread,
};
use tracing::error;
use tracing_subscriber::{EnvFilter, FmtSubscriber};
mod configuration;
mod custom_requests;
mod memory_backends;
mod memory_worker;
mod template;
mod transformer_backends;
mod transformer_worker;
mod utils;
mod worker;
use configuration::Configuration;
use custom_requests::generate::Generate;
use memory_backends::MemoryBackend;
use transformer_backends::TransformerBackend;
use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest};
use transformer_worker::{CompletionRequest, GenerateRequest, WorkerRequest};
use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest};
use crate::{
custom_requests::generate_stream::GenerateStream, transformer_worker::GenerateStreamRequest,
};
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
notification.method == N::METHOD
@@ -52,7 +58,7 @@ fn main() -> Result<()> {
.init();
let (connection, io_threads) = Connection::stdio();
let server_capabilities = serde_json::to_value(&ServerCapabilities {
let server_capabilities = serde_json::to_value(ServerCapabilities {
completion_provider: Some(CompletionOptions::default()),
text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind(
TextDocumentSyncKind::INCREMENTAL,
@@ -66,38 +72,40 @@ fn main() -> Result<()> {
Ok(())
}
// This main loop is tricky.
// We spawn a worker thread to do the heavy lifting because we do not want to process every completion request we receive:
// completion requests may take a few seconds depending on the model configuration and the available hardware, so we only process the latest one.
// Note that the memory backend also lives in its own worker thread, as it may involve heavy computation as well.
fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let args = Configuration::new(args)?;
// Set the transformer_backend
let transformer_backend: Box<dyn TransformerBackend + Send> = args.clone().try_into()?;
// Set the memory_backend
let memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>> =
Arc::new(Mutex::new(args.clone().try_into()?));
// Build our configuration
let configuration = Configuration::new(args)?;
// Wrap the connection for sharing between threads
let connection = Arc::new(connection);
// How we communicate between the worker and receiver threads
// Our channel we use to communicate with our transformer_worker
let last_worker_request = Arc::new(Mutex::new(None));
// Setup our memory_worker
// TODO: Setup some kind of error handler
// Set the memory_backend
// The channel we use to communicate with our memory_worker
let (memory_tx, memory_rx) = mpsc::channel();
let memory_backend: Box<dyn MemoryBackend + Send + Sync> = configuration.clone().try_into()?;
thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
// Setup our transformer_worker
// Thread local variables
let thread_memory_backend = memory_backend.clone();
// TODO: Setup some kind of handler for errors here
// Set the transformer_backend
let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
configuration.clone().try_into()?;
let thread_last_worker_request = last_worker_request.clone();
let thread_connection = connection.clone();
let thread_memory_tx = memory_tx.clone();
thread::spawn(move || {
Worker::new(
transformer_worker::run(
transformer_backend,
thread_memory_backend,
thread_memory_tx,
thread_last_worker_request,
thread_connection,
)
.run();
});
for msg in &connection.receiver {
@@ -143,13 +151,13 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
Message::Notification(not) => {
if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
memory_backend.lock().opened_text_document(params)?;
memory_tx.send(memory_worker::WorkerRequest::DidOpenTextDocument(params))?;
} else if notification_is::<lsp_types::notification::DidChangeTextDocument>(&not) {
let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
memory_backend.lock().changed_text_document(params)?;
memory_tx.send(memory_worker::WorkerRequest::DidChangeTextDocument(params))?;
} else if notification_is::<lsp_types::notification::DidRenameFiles>(&not) {
let params: RenameFilesParams = serde_json::from_value(not.params)?;
memory_backend.lock().renamed_file(params)?;
memory_tx.send(memory_worker::WorkerRequest::DidRenameFiles(params))?;
}
}
_ => (),
@@ -170,18 +178,19 @@ mod tests {
//////////////////////////////////////
//////////////////////////////////////
#[test]
fn completion_with_default_arguments() {
#[tokio::test]
async fn completion_with_default_arguments() {
let args = json!({});
let configuration = Configuration::new(args).unwrap();
let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
let backend: Box<dyn TransformerBackend + Send + Sync> =
configuration.clone().try_into().unwrap();
let prompt = Prompt::new("".to_string(), "def fibn".to_string());
let response = backend.do_completion(&prompt).unwrap();
let response = backend.do_completion(&prompt).await.unwrap();
assert!(!response.insert_text.is_empty())
}
#[test]
fn completion_with_custom_gguf_model() {
#[tokio::test]
async fn completion_with_custom_gguf_model() {
let args = json!({
"initializationOptions": {
"memory": {
@@ -230,9 +239,10 @@ mod tests {
}
});
let configuration = Configuration::new(args).unwrap();
let backend: Box<dyn TransformerBackend + Send> = configuration.clone().try_into().unwrap();
let backend: Box<dyn TransformerBackend + Send + Sync> =
configuration.clone().try_into().unwrap();
let prompt = Prompt::new("".to_string(), "def fibn".to_string());
let response = backend.do_completion(&prompt).unwrap();
let response = backend.do_completion(&prompt).await.unwrap();
assert!(!response.insert_text.is_empty());
}
}
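
The "tricky main loop" comment above describes a latest-request-wins hand-off: the LSP loop overwrites a shared `Option<WorkerRequest>` slot, and the transformer worker drains whatever is newest on its own schedule, so stale completion requests are dropped instead of queued. A minimal, self-contained sketch of that pattern (standard library only; the names and integer "requests" are illustrative, not the crate's types):

use std::sync::{Arc, Mutex};
use std::{thread, time::Duration};

fn main() {
    // Shared slot holding only the most recent request.
    let last_request: Arc<Mutex<Option<u64>>> = Arc::new(Mutex::new(None));

    // Producer: each new request simply overwrites the previous one.
    let producer_slot = last_request.clone();
    thread::spawn(move || {
        for id in 0..100u64 {
            *producer_slot.lock().unwrap() = Some(id);
            thread::sleep(Duration::from_millis(1));
        }
    });

    // Worker: periodically takes whatever is newest and processes only that.
    for _ in 0..20 {
        let newest = last_request.lock().unwrap().take();
        if let Some(id) = newest {
            println!("processing latest request {id}");
        }
        thread::sleep(Duration::from_millis(5));
    }
}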


@@ -1,8 +1,9 @@
use anyhow::Context;
use indexmap::IndexSet;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use ropey::Rope;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use tracing::instrument;
use crate::{
@@ -15,8 +16,8 @@ use super::{MemoryBackend, Prompt, PromptForType};
pub struct FileStore {
crawl: bool,
configuration: Configuration,
file_map: HashMap<String, Rope>,
accessed_files: IndexSet<String>,
file_map: Mutex<HashMap<String, Rope>>,
accessed_files: Mutex<IndexSet<String>>,
}
// TODO: Put some thought into the crawling here. Do we want to have a crawl option where it tries to crawl through all relevant
@@ -37,8 +38,8 @@ impl FileStore {
Self {
crawl: file_store_config.crawl,
configuration,
file_map: HashMap::new(),
accessed_files: IndexSet::new(),
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
}
}
@@ -46,8 +47,8 @@ impl FileStore {
Self {
crawl: false,
configuration,
file_map: HashMap::new(),
accessed_files: IndexSet::new(),
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
}
}
@@ -60,6 +61,7 @@ impl FileStore {
let current_document_uri = position.text_document.uri.to_string();
let mut rope = self
.file_map
.lock()
.get(&current_document_uri)
.context("Error file not found")?
.clone();
@@ -68,14 +70,16 @@ impl FileStore {
// Add to our rope if we need to
for file in self
.accessed_files
.lock()
.iter()
.filter(|f| **f != current_document_uri)
{
let needed = characters.checked_sub(rope.len_chars()).unwrap_or(0);
let needed = characters.saturating_sub(rope.len_chars());
if needed == 0 {
break;
}
let r = self.file_map.get(file).context("Error file not found")?;
let file_map = self.file_map.lock();
let r = file_map.get(file).context("Error file not found")?;
let slice_max = needed.min(r.len_chars());
let rope_str_slice = r
.get_slice(0..slice_max)
@@ -94,12 +98,13 @@ impl FileStore {
) -> anyhow::Result<String> {
let rope = self
.file_map
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
let start = cursor_index.checked_sub(characters / 2).unwrap_or(0);
let start = cursor_index.saturating_sub(characters / 2);
let end = rope
.len_chars()
.min(cursor_index + (characters - (cursor_index - start)));
@@ -137,15 +142,15 @@ impl FileStore {
if is_chat_enabled || rope.len_chars() != cursor_index =>
{
let max_length = tokens_to_estimated_characters(max_context_length);
let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
if is_chat_enabled {
rope.insert(cursor_index, "{CURSOR}");
rope.insert(cursor_index, "<CURSOR>");
let rope_slice = rope
.get_slice(start..end + "{CURSOR}".chars().count())
.get_slice(start..end + "<CURSOR>".chars().count())
.context("Error getting rope slice")?;
rope_slice.to_string()
} else {
@@ -166,9 +171,8 @@ impl FileStore {
}
}
_ => {
let start = cursor_index
.checked_sub(tokens_to_estimated_characters(max_context_length))
.unwrap_or(0);
let start =
cursor_index.saturating_sub(tokens_to_estimated_characters(max_context_length));
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
@@ -178,11 +182,16 @@ impl FileStore {
}
}
#[async_trait::async_trait]
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String> {
let rope = self
.file_map
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
@@ -193,8 +202,8 @@ impl MemoryBackend for FileStore {
}
#[instrument(skip(self))]
fn build_prompt(
&mut self,
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
@@ -207,25 +216,25 @@ impl MemoryBackend for FileStore {
}
#[instrument(skip(self))]
fn opened_text_document(
&mut self,
async fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let rope = Rope::from_str(&params.text_document.text);
let uri = params.text_document.uri.to_string();
self.file_map.insert(uri.clone(), rope);
self.accessed_files.shift_insert(0, uri);
self.file_map.lock().insert(uri.clone(), rope);
self.accessed_files.lock().shift_insert(0, uri);
Ok(())
}
#[instrument(skip(self))]
fn changed_text_document(
&mut self,
async fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let uri = params.text_document.uri.to_string();
let rope = self
.file_map
let mut file_map = self.file_map.lock();
let rope = file_map
.get_mut(&uri)
.context("Error trying to get file that does not exist")?;
for change in params.content_changes {
@@ -241,15 +250,16 @@ impl MemoryBackend for FileStore {
*rope = Rope::from_str(&change.text);
}
}
self.accessed_files.shift_insert(0, uri);
self.accessed_files.lock().shift_insert(0, uri);
Ok(())
}
#[instrument(skip(self))]
fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
async fn renamed_file(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
for file_rename in params.files {
if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {
self.file_map.insert(file_rename.new_uri, rope);
let mut file_map = self.file_map.lock();
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
file_map.insert(file_rename.new_uri, rope);
}
}
Ok(())

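
The substantive change in the file store is the move from `&mut self` methods to `&self` methods with `parking_lot::Mutex` around `file_map` and `accessed_files`: interior mutability is what lets the backend sit behind an `Arc<Box<dyn MemoryBackend + Send + Sync>>` and be touched from several tasks. A reduced sketch of the pattern, assuming the parking_lot crate (already a dependency) and illustrative names rather than the crate's `FileStore`:

use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::Arc;

// &self methods that still mutate state, via a Mutex around the map.
struct Store {
    files: Mutex<HashMap<String, String>>,
}

impl Store {
    fn upsert(&self, uri: &str, text: &str) {
        self.files.lock().insert(uri.to_string(), text.to_string());
    }

    fn len_chars(&self, uri: &str) -> Option<usize> {
        self.files.lock().get(uri).map(|t| t.chars().count())
    }
}

fn main() {
    let store = Arc::new(Store {
        files: Mutex::new(HashMap::new()),
    });
    let writer = store.clone();
    std::thread::spawn(move || writer.upsert("file:///a.rs", "fn main() {}"))
        .join()
        .unwrap();
    assert_eq!(store.len_chars("file:///a.rs"), Some(12));
}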

@@ -26,22 +26,29 @@ pub enum PromptForType {
Generate,
}
#[async_trait::async_trait]
pub trait MemoryBackend {
fn init(&self) -> anyhow::Result<()> {
async fn init(&self) -> anyhow::Result<()> {
Ok(())
}
fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
fn build_prompt(
&mut self,
async fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
async fn changed_text_document(
&self,
params: DidChangeTextDocumentParams,
) -> anyhow::Result<()>;
async fn renamed_file(&self, params: RenameFilesParams) -> anyhow::Result<()>;
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt>;
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String>;
}
impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send + Sync> {
type Error = anyhow::Error;
fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
@@ -55,3 +62,22 @@ impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
}
}
}
// This makes testing much easier. Every transformer backend takes in a prompt. When verifying they work, it's
// easier to just pass in a default prompt.
#[cfg(test)]
impl Prompt {
pub fn default_with_cursor() -> Self {
Self {
context: r#"def test_context():\n pass"#.to_string(),
code: r#"def test_code():\n <CURSOR>"#.to_string(),
}
}
pub fn default_without_cursor() -> Self {
Self {
context: r#"def test_context():\n pass"#.to_string(),
code: r#"def test_code():\n "#.to_string(),
}
}
}
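
The trait itself is now async via the async-trait crate, and the boxed backend picks up a `Sync` bound so it can be shared across tokio tasks. A stripped-down sketch of that combination, with an illustrative `Backend` trait and no-op implementation (not the crate's `MemoryBackend`); assumes async-trait, anyhow, and tokio with the `macros` and `rt-multi-thread` features:

use async_trait::async_trait;
use std::sync::Arc;

#[async_trait]
trait Backend {
    async fn filter_text(&self, uri: &str) -> anyhow::Result<String>;
}

struct Noop;

#[async_trait]
impl Backend for Noop {
    async fn filter_text(&self, uri: &str) -> anyhow::Result<String> {
        Ok(format!("filter text for {uri}"))
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Send + Sync is what allows the trait object to cross task boundaries.
    let backend: Arc<Box<dyn Backend + Send + Sync>> = Arc::new(Box::new(Noop));
    let task_backend = backend.clone();
    let text = tokio::spawn(async move { task_backend.filter_text("file:///a.rs").await })
        .await??;
    println!("{text}");
    Ok(())
}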


@@ -40,7 +40,7 @@ impl PostgresML {
};
// TODO: Think on the naming of the collection
// Maybe filter on metadata or I'm not sure
let collection = Collection::new("test-lsp-ai-2", Some(database_url))?;
let collection = Collection::new("test-lsp-ai-3", Some(database_url))?;
// TODO: Review the pipeline
let pipeline = Pipeline::new(
"v1",
@@ -50,7 +50,7 @@ impl PostgresML {
"splitter": {
"model": "recursive_character",
"parameters": {
"chunk_size": 512,
"chunk_size": 1500,
"chunk_overlap": 40
}
},
@@ -90,7 +90,7 @@ impl PostgresML {
.into_iter()
.map(|path| {
let text = std::fs::read_to_string(&path)
.expect(format!("Error reading path: {}", path).as_str());
.unwrap_or_else(|_| panic!("Error reading path: {}", path));
json!({
"id": path,
"text": text
@@ -121,24 +121,28 @@ impl PostgresML {
}
}
#[async_trait::async_trait]
impl MemoryBackend for PostgresML {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
self.file_store.get_filter_text(position)
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String> {
self.file_store.get_filter_text(position).await
}
#[instrument(skip(self))]
fn build_prompt(
&mut self,
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
// This is blocking, but that is ok as we only query for it from the worker when we are actually doing a transform
let query = self
.file_store
.get_characters_around_position(position, 512)?;
let res = self.runtime.block_on(
self.collection.vector_search(
let res = self
.collection
.vector_search_local(
json!({
"query": {
"fields": {
@@ -150,9 +154,9 @@ impl MemoryBackend for PostgresML {
"limit": 5
})
.into(),
&mut self.pipeline,
),
)?;
&self.pipeline,
)
.await?;
let context = res
.into_iter()
.map(|c| {
@@ -176,8 +180,8 @@ impl MemoryBackend for PostgresML {
}
#[instrument(skip(self))]
fn opened_text_document(
&mut self,
async fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let text = params.text_document.text.clone();
@@ -185,68 +189,63 @@ impl MemoryBackend for PostgresML {
let task_added_pipeline = self.added_pipeline;
let mut task_collection = self.collection.clone();
let mut task_pipeline = self.pipeline.clone();
self.runtime.spawn(async move {
if !task_added_pipeline {
task_collection
.add_pipeline(&mut task_pipeline)
.await
.expect("PGML - Error adding pipeline to collection");
}
if !task_added_pipeline {
task_collection
.add_pipeline(&mut task_pipeline)
.await
.expect("PGML - Error adding pipeline to collection");
}
task_collection
.upsert_documents(
vec![json!({
"id": path,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error upserting documents");
self.file_store.opened_text_document(params).await
}
#[instrument(skip(self))]
async fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let path = params.text_document.uri.path().to_owned();
self.debounce_tx.send(path)?;
self.file_store.changed_text_document(params).await
}
#[instrument(skip(self))]
async fn renamed_file(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
let mut task_collection = self.collection.clone();
let task_params = params.clone();
for file in task_params.files {
task_collection
.delete_documents(
json!({
"id": file.old_uri
})
.into(),
)
.await
.expect("PGML - Error deleting file");
let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
task_collection
.upsert_documents(
vec![json!({
"id": path,
"id": file.new_uri,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error upserting documents");
});
self.file_store.opened_text_document(params)
}
#[instrument(skip(self))]
fn changed_text_document(
&mut self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let path = params.text_document.uri.path().to_owned();
self.debounce_tx.send(path)?;
self.file_store.changed_text_document(params)
}
#[instrument(skip(self))]
fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
let mut task_collection = self.collection.clone();
let task_params = params.clone();
self.runtime.spawn(async move {
for file in task_params.files {
task_collection
.delete_documents(
json!({
"id": file.old_uri
})
.into(),
)
.await
.expect("PGML - Error deleting file");
let text =
std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
task_collection
.upsert_documents(
vec![json!({
"id": file.new_uri,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error adding pipeline to collection");
}
});
self.file_store.renamed_file(params)
.expect("PGML - Error adding pipeline to collection");
}
self.file_store.renamed_file(params).await
}
}

src/memory_worker.rs (new file, 100 lines)

@@ -0,0 +1,100 @@
use std::sync::Arc;
use lsp_types::{
DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
TextDocumentPositionParams,
};
use crate::memory_backends::{MemoryBackend, Prompt, PromptForType};
#[derive(Debug)]
pub struct PromptRequest {
position: TextDocumentPositionParams,
prompt_for_type: PromptForType,
tx: tokio::sync::oneshot::Sender<Prompt>,
}
impl PromptRequest {
pub fn new(
position: TextDocumentPositionParams,
prompt_for_type: PromptForType,
tx: tokio::sync::oneshot::Sender<Prompt>,
) -> Self {
Self {
position,
prompt_for_type,
tx,
}
}
}
#[derive(Debug)]
pub struct FilterRequest {
position: TextDocumentPositionParams,
tx: tokio::sync::oneshot::Sender<String>,
}
impl FilterRequest {
pub fn new(
position: TextDocumentPositionParams,
tx: tokio::sync::oneshot::Sender<String>,
) -> Self {
Self { position, tx }
}
}
pub enum WorkerRequest {
FilterText(FilterRequest),
Prompt(PromptRequest),
DidOpenTextDocument(DidOpenTextDocumentParams),
DidChangeTextDocument(DidChangeTextDocumentParams),
DidRenameFiles(RenameFilesParams),
}
pub fn run(
memory_backend: Box<dyn MemoryBackend + Send + Sync>,
rx: std::sync::mpsc::Receiver<WorkerRequest>,
) -> anyhow::Result<()> {
let memory_backend = Arc::new(memory_backend);
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()?;
loop {
let request = rx.recv()?;
let thread_memory_backend = memory_backend.clone();
runtime.spawn(async move {
match request {
WorkerRequest::FilterText(params) => {
let filter_text = thread_memory_backend
.get_filter_text(&params.position)
.await
.unwrap();
params.tx.send(filter_text).unwrap();
}
WorkerRequest::Prompt(params) => {
let prompt = thread_memory_backend
.build_prompt(&params.position, params.prompt_for_type)
.await
.unwrap();
params.tx.send(prompt).unwrap();
}
WorkerRequest::DidOpenTextDocument(params) => {
thread_memory_backend
.opened_text_document(params)
.await
.unwrap();
}
WorkerRequest::DidChangeTextDocument(params) => {
thread_memory_backend
.changed_text_document(params)
.await
.unwrap();
}
WorkerRequest::DidRenameFiles(params) => {
thread_memory_backend.renamed_file(params).await.unwrap()
}
}
});
}
}
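
The worker above is driven by a blocking std::sync::mpsc receiver, and each request carries a tokio oneshot sender for its reply, so callers on the LSP thread can ask a question and wait for the answer while the actual work runs on the worker's runtime. A self-contained sketch of that request/response shape (illustrative names, not the crate's types; assumes tokio with the `rt-multi-thread` and `sync` features):

use std::sync::mpsc;
use std::thread;
use tokio::sync::oneshot;

// A request bundles its inputs with a oneshot channel for the reply.
struct FilterRequest {
    uri: String,
    tx: oneshot::Sender<String>,
}

fn main() {
    let (req_tx, req_rx) = mpsc::channel::<FilterRequest>();

    // Worker thread: owns its own tokio runtime and answers each request.
    thread::spawn(move || {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("failed to build runtime");
        while let Ok(request) = req_rx.recv() {
            runtime.spawn(async move {
                let _ = request.tx.send(format!("filter text for {}", request.uri));
            });
        }
    });

    // Caller side: send the request, then block on the reply.
    let (tx, rx) = oneshot::channel();
    req_tx
        .send(FilterRequest {
            uri: "file:///a.rs".into(),
            tx,
        })
        .expect("memory worker is gone");
    let answer = rx.blocking_recv().expect("worker dropped the reply channel");
    println!("{answer}");
}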


@@ -0,0 +1,207 @@
use anyhow::Context;
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::instrument;
use crate::{
configuration::{self, ChatMessage},
memory_backends::Prompt,
transformer_worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
},
utils::format_chat_messages,
};
use super::TransformerBackend;
pub struct Anthropic {
configuration: configuration::Anthropic,
}
#[derive(Deserialize)]
struct AnthropicChatMessage {
text: String,
}
#[derive(Deserialize)]
struct AnthropicChatResponse {
content: Option<Vec<AnthropicChatMessage>>,
error: Option<Value>,
}
impl Anthropic {
#[instrument]
pub fn new(configuration: configuration::Anthropic) -> Self {
Self { configuration }
}
async fn get_chat(
&self,
system_prompt: String,
messages: Vec<ChatMessage>,
max_tokens: usize,
) -> anyhow::Result<String> {
eprintln!(
"SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
messages
);
let client = reqwest::Client::new();
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
std::env::var(env_var_name)?
} else if let Some(token) = &self.configuration.auth_token {
token.to_string()
} else {
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `anthropic` to use the Anthropic API");
};
let res: AnthropicChatResponse = client
.post(
self.configuration
.chat_endpoint
.as_ref()
.context("must specify `chat_endpoint` to use chat")?,
)
.header("x-api-key", token)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&json!({
"model": self.configuration.model,
"system": system_prompt,
"max_tokens": max_tokens,
"top_p": self.configuration.top_p,
"temperature": self.configuration.temperature,
"messages": messages
}))
.send()
.await?
.json()
.await?;
if let Some(error) = res.error {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(mut content) = res.content {
Ok(std::mem::take(&mut content[0].text))
} else {
anyhow::bail!("Unknown error while making request to Anthropic")
}
}
async fn do_get_chat(
&self,
prompt: &Prompt,
messages: &[ChatMessage],
max_tokens: usize,
) -> anyhow::Result<String> {
let mut messages = format_chat_messages(messages, prompt);
if messages[0].role != "system" {
anyhow::bail!(
"When using Anthropic, the first message in chat must have role = `system`"
)
}
let system_prompt = messages.remove(0).content;
self.get_chat(system_prompt, messages, max_tokens).await
}
}
#[async_trait::async_trait]
impl TransformerBackend for Anthropic {
#[instrument(skip(self))]
async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
eprintln!("--------------{:?}---------------", prompt);
let max_tokens = self.configuration.max_tokens.completion;
let insert_text = match &self.configuration.chat.completion {
Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
None => {
anyhow::bail!("Please provide `anthropic->chat->completion` messages")
}
};
Ok(DoCompletionResponse { insert_text })
}
#[instrument(skip(self))]
async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
eprintln!("--------------{:?}---------------", prompt);
let max_tokens = self.configuration.max_tokens.generation;
let generated_text = match &self.configuration.chat.generation {
Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
None => {
anyhow::bail!("Please provide `anthropic->chat->generation` messages")
}
};
Ok(DoGenerateResponse { generated_text })
}
#[instrument(skip(self))]
async fn do_generate_stream(
&self,
request: &GenerateStreamRequest,
) -> anyhow::Result<DoGenerateStreamResponse> {
unimplemented!()
}
}
#[cfg(test)]
mod test {
use super::*;
#[tokio::test]
async fn anthropic_chat_do_completion() -> anyhow::Result<()> {
let configuration: configuration::Anthropic = serde_json::from_value(json!({
"chat_endpoint": "https://api.anthropic.com/v1/messages",
"model": "claude-3-haiku-20240307",
"auth_token_env_var_name": "ANTHROPIC_API_KEY",
"chat": {
"completion": [
{
"role": "system",
"content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
},
{
"role": "user",
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
}
],
},
"max_tokens": {
"completion": 16,
"generation": 64
},
"max_context": 4096
}))?;
let anthropic = Anthropic::new(configuration);
let prompt = Prompt::default_with_cursor();
let response = anthropic.do_completion(&prompt).await?;
assert!(!response.insert_text.is_empty());
Ok(())
}
#[tokio::test]
async fn anthropic_chat_do_generate() -> anyhow::Result<()> {
let configuration: configuration::Anthropic = serde_json::from_value(json!({
"chat_endpoint": "https://api.anthropic.com/v1/messages",
"model": "claude-3-haiku-20240307",
"auth_token_env_var_name": "ANTHROPIC_API_KEY",
"chat": {
"generation": [
{
"role": "system",
"content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
},
{
"role": "user",
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
}
]
},
"max_tokens": {
"completion": 16,
"generation": 64
},
"max_context": 4096
}))?;
let anthropic = Anthropic::new(configuration);
let prompt = Prompt::default_with_cursor();
let response = anthropic.do_generate(&prompt).await?;
assert!(!response.generated_text.is_empty());
Ok(())
}
}


@@ -2,20 +2,21 @@ use anyhow::Context;
use hf_hub::api::sync::ApiBuilder;
use tracing::{debug, instrument};
use super::TransformerBackend;
use crate::{
configuration::{self},
memory_backends::Prompt,
template::apply_chat_template,
utils::format_chat_messages,
worker::{
transformer_worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
},
utils::format_chat_messages,
};
mod model;
use model::Model;
use super::TransformerBackend;
pub struct LlamaCPP {
model: Model,
configuration: configuration::ModelGGUF,
@@ -62,9 +63,10 @@ impl LlamaCPP {
}
}
#[async_trait::async_trait]
impl TransformerBackend for LlamaCPP {
#[instrument(skip(self))]
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
// let prompt = self.get_prompt_string(prompt)?;
let prompt = &prompt.code;
debug!("Prompt string for LLM: {}", prompt);
@@ -75,7 +77,7 @@ impl TransformerBackend for LlamaCPP {
}
#[instrument(skip(self))]
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
// let prompt = self.get_prompt_string(prompt)?;
// debug!("Prompt string for LLM: {}", prompt);
let prompt = &prompt.code;
@@ -86,7 +88,7 @@ impl TransformerBackend for LlamaCPP {
}
#[instrument(skip(self))]
fn do_generate_stream(
async fn do_generate_stream(
&self,
_request: &GenerateStreamRequest,
) -> anyhow::Result<DoGenerateStreamResponse> {


@@ -1,25 +1,26 @@
use crate::{
configuration::{Configuration, ValidTransformerBackend},
memory_backends::Prompt,
worker::{
transformer_worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
},
};
mod anthropic;
mod llama_cpp;
mod openai;
#[async_trait::async_trait]
pub trait TransformerBackend {
// Should all take an enum of chat messages or just a string for completion
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
fn do_generate_stream(
async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
async fn do_generate_stream(
&self,
request: &GenerateStreamRequest,
) -> anyhow::Result<DoGenerateStreamResponse>;
}
impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send> {
impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send + Sync> {
type Error = anyhow::Error;
fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
@@ -30,6 +31,9 @@ impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send> {
ValidTransformerBackend::OpenAI(openai_config) => {
Ok(Box::new(openai::OpenAI::new(openai_config)))
}
ValidTransformerBackend::Anthropic(anthropic_config) => {
Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
}
}
}
}


@@ -1,16 +1,22 @@
// Something more about what this file is
// NOTE: When decoding responses from OpenAI compatible services, we don't care about every field
use anyhow::Context;
use serde::Deserialize;
use serde_json::json;
use serde_json::{json, Value};
use tracing::instrument;
use super::TransformerBackend;
use crate::{
configuration,
configuration::{self, ChatMessage},
memory_backends::Prompt,
worker::{
transformer_worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
},
utils::{format_chat_messages, format_context_code},
};
use super::TransformerBackend;
pub struct OpenAI {
configuration: configuration::OpenAI,
}
@@ -22,7 +28,19 @@ struct OpenAICompletionsChoice {
#[derive(Deserialize)]
struct OpenAICompletionsResponse {
choices: Vec<OpenAICompletionsChoice>,
choices: Option<Vec<OpenAICompletionsChoice>>,
error: Option<Value>,
}
#[derive(Deserialize)]
struct OpenAIChatChoices {
message: ChatMessage,
}
#[derive(Deserialize)]
struct OpenAIChatResponse {
choices: Option<Vec<OpenAIChatChoices>>,
error: Option<Value>,
}
impl OpenAI {
@@ -31,18 +49,27 @@ impl OpenAI {
Self { configuration }
}
fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
let client = reqwest::blocking::Client::new();
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
std::env::var(env_var_name)?
fn get_token(&self) -> anyhow::Result<String> {
if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
Ok(std::env::var(env_var_name)?)
} else if let Some(token) = &self.configuration.auth_token {
token.to_string()
Ok(token.to_string())
} else {
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
};
anyhow::bail!("set `auth_token_env_var_name` or `auth_token` in `transformer->openai` to use an OpenAI compatible API")
}
}
async fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: OpenAICompletionsResponse = client
.post(&self.configuration.completions_endpoint)
.post(
self.configuration
.completions_endpoint
.as_ref()
.context("specify `transformer->openai->completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `transformer->openai->chat_endpoint` and `transformer->openai->chat` messages.")?,
)
.bearer_auth(token)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
@@ -51,43 +78,231 @@ impl OpenAI {
"max_tokens": max_tokens,
"n": 1,
"top_p": self.configuration.top_p,
"top_k": self.configuration.top_k,
"presence_penalty": self.configuration.presence_penalty,
"frequency_penalty": self.configuration.frequency_penalty,
"temperature": self.configuration.temperature,
"echo": false,
"prompt": prompt
}))
.send()?
.json()?;
eprintln!("**********RECEIVED REQUEST********");
Ok(res.choices[0].text.clone())
.send().await?
.json().await?;
if let Some(error) = res.error {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(mut choices) = res.choices {
Ok(std::mem::take(&mut choices[0].text))
} else {
anyhow::bail!("Unknown error while making request to OpenAI")
}
}
async fn get_chat(
&self,
messages: Vec<ChatMessage>,
max_tokens: usize,
) -> anyhow::Result<String> {
eprintln!(
"SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
messages
);
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: OpenAIChatResponse = client
.post(
self.configuration
.chat_endpoint
.as_ref()
.context("must specify `chat_endpoint` to use chat")?,
)
.bearer_auth(token)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&json!({
"model": self.configuration.model,
"max_tokens": max_tokens,
"n": 1,
"top_p": self.configuration.top_p,
"presence_penalty": self.configuration.presence_penalty,
"frequency_penalty": self.configuration.frequency_penalty,
"temperature": self.configuration.temperature,
"messages": messages
}))
.send()
.await?
.json()
.await?;
if let Some(error) = res.error {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(choices) = res.choices {
Ok(choices[0].message.content.clone())
} else {
anyhow::bail!("Unknown error while making request to OpenAI")
}
}
async fn do_chat_completion(
&self,
prompt: &Prompt,
messages: Option<&Vec<ChatMessage>>,
max_tokens: usize,
) -> anyhow::Result<String> {
match messages {
Some(completion_messages) => {
let messages = format_chat_messages(completion_messages, prompt);
self.get_chat(messages, max_tokens).await
}
None => {
self.get_completion(
&format_context_code(&prompt.context, &prompt.code),
max_tokens,
)
.await
}
}
}
}
#[async_trait::async_trait]
impl TransformerBackend for OpenAI {
#[instrument(skip(self))]
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
eprintln!("--------------{:?}---------------", prompt);
let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
let insert_text = self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
let max_tokens = self.configuration.max_tokens.completion;
let messages = self
.configuration
.chat
.as_ref()
.and_then(|c| c.completion.as_ref());
let insert_text = self
.do_chat_completion(prompt, messages, max_tokens)
.await?;
Ok(DoCompletionResponse { insert_text })
}
#[instrument(skip(self))]
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
eprintln!("--------------{:?}---------------", prompt);
let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
let generated_text =
self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
let max_tokens = self.configuration.max_tokens.generation;
let messages = self
.configuration
.chat
.as_ref()
.and_then(|c| c.generation.as_ref());
let generated_text = self
.do_chat_completion(prompt, messages, max_tokens)
.await?;
Ok(DoGenerateResponse { generated_text })
}
#[instrument(skip(self))]
fn do_generate_stream(
async fn do_generate_stream(
&self,
request: &GenerateStreamRequest,
) -> anyhow::Result<DoGenerateStreamResponse> {
unimplemented!()
}
}
#[cfg(test)]
mod test {
use super::*;
#[tokio::test]
async fn openai_completion_do_completion() -> anyhow::Result<()> {
let configuration: configuration::OpenAI = serde_json::from_value(json!({
"completions_endpoint": "https://api.openai.com/v1/completions",
"model": "gpt-3.5-turbo-instruct",
"auth_token_env_var_name": "OPENAI_API_KEY",
"max_tokens": {
"completion": 16,
"generation": 64
},
"max_context": 4096
}))?;
let openai = OpenAI::new(configuration);
let prompt = Prompt::default_without_cursor();
let response = openai.do_completion(&prompt).await?;
assert!(!response.insert_text.is_empty());
Ok(())
}
#[tokio::test]
async fn openai_chat_do_completion() -> anyhow::Result<()> {
let configuration: configuration::OpenAI = serde_json::from_value(json!({
"chat_endpoint": "https://api.openai.com/v1/chat/completions",
"model": "gpt-3.5-turbo",
"auth_token_env_var_name": "OPENAI_API_KEY",
"chat": {
"completion": [
{
"role": "system",
"content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
},
{
"role": "user",
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
}
],
},
"max_tokens": {
"completion": 16,
"generation": 64
},
"max_context": 4096
}))?;
let openai = OpenAI::new(configuration);
let prompt = Prompt::default_with_cursor();
let response = openai.do_completion(&prompt).await?;
assert!(!response.insert_text.is_empty());
Ok(())
}
#[tokio::test]
async fn openai_completion_do_generate() -> anyhow::Result<()> {
let configuration: configuration::OpenAI = serde_json::from_value(json!({
"completions_endpoint": "https://api.openai.com/v1/completions",
"model": "gpt-3.5-turbo-instruct",
"auth_token_env_var_name": "OPENAI_API_KEY",
"max_tokens": {
"completion": 16,
"generation": 64
},
"max_context": 4096
}))?;
let openai = OpenAI::new(configuration);
let prompt = Prompt::default_without_cursor();
let response = openai.do_generate(&prompt).await?;
assert!(!response.generated_text.is_empty());
Ok(())
}
#[tokio::test]
async fn openai_chat_do_generate() -> anyhow::Result<()> {
let configuration: configuration::OpenAI = serde_json::from_value(json!({
"chat_endpoint": "https://api.openai.com/v1/chat/completions",
"model": "gpt-3.5-turbo",
"auth_token_env_var_name": "OPENAI_API_KEY",
"chat": {
"generation": [
{
"role": "system",
"content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
},
{
"role": "user",
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
}
]
},
"max_tokens": {
"completion": 16,
"generation": 64
},
"max_context": 4096
}))?;
let openai = OpenAI::new(configuration);
let prompt = Prompt::default_with_cursor();
let response = openai.do_generate(&prompt).await?;
assert!(!response.generated_text.is_empty());
Ok(())
}
}

src/transformer_worker.rs (new file, 214 lines)

@@ -0,0 +1,214 @@
use lsp_server::{Connection, Message, RequestId, Response};
use lsp_types::{
CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
Position, Range, TextEdit,
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use tokio::sync::oneshot;
use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::GenerateStreamParams;
use crate::memory_backends::PromptForType;
use crate::memory_worker::{self, FilterRequest, PromptRequest};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;
#[derive(Clone, Debug)]
pub struct CompletionRequest {
id: RequestId,
params: CompletionParams,
}
impl CompletionRequest {
pub fn new(id: RequestId, params: CompletionParams) -> Self {
Self { id, params }
}
}
#[derive(Clone, Debug)]
pub struct GenerateRequest {
id: RequestId,
params: GenerateParams,
}
impl GenerateRequest {
pub fn new(id: RequestId, params: GenerateParams) -> Self {
Self { id, params }
}
}
// The generate stream is not yet ready but we don't want to remove it
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct GenerateStreamRequest {
id: RequestId,
params: GenerateStreamParams,
}
impl GenerateStreamRequest {
pub fn new(id: RequestId, params: GenerateStreamParams) -> Self {
Self { id, params }
}
}
#[derive(Clone)]
pub enum WorkerRequest {
Completion(CompletionRequest),
Generate(GenerateRequest),
GenerateStream(GenerateStreamRequest),
}
pub struct DoCompletionResponse {
pub insert_text: String,
}
pub struct DoGenerateResponse {
pub generated_text: String,
}
pub struct DoGenerateStreamResponse {
pub generated_text: String,
}
pub fn run(
transformer_backend: Box<dyn TransformerBackend + Send + Sync>,
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
connection: Arc<Connection>,
) -> anyhow::Result<()> {
let transformer_backend = Arc::new(transformer_backend);
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()?;
loop {
let option_worker_request: Option<WorkerRequest> = {
let mut completion_request = last_worker_request.lock();
std::mem::take(&mut *completion_request)
};
if let Some(request) = option_worker_request {
let thread_connection = connection.clone();
let thread_transformer_backend = transformer_backend.clone();
let thread_memory_backend_tx = memory_backend_tx.clone();
runtime.spawn(async move {
let response = match request {
WorkerRequest::Completion(request) => match do_completion(
thread_transformer_backend,
thread_memory_backend_tx,
&request,
)
.await
{
Ok(r) => r,
Err(e) => Response {
id: request.id,
result: None,
error: Some(e.to_response_error(-32603)),
},
},
WorkerRequest::Generate(request) => match do_generate(
thread_transformer_backend,
thread_memory_backend_tx,
&request,
)
.await
{
Ok(r) => r,
Err(e) => Response {
id: request.id,
result: None,
error: Some(e.to_response_error(-32603)),
},
},
WorkerRequest::GenerateStream(_) => {
panic!("Streaming is not supported yet")
}
};
thread_connection
.sender
.send(Message::Response(response))
.expect("Error sending message");
});
}
thread::sleep(std::time::Duration::from_millis(5));
}
}
async fn do_completion(
transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
request: &CompletionRequest,
) -> anyhow::Result<Response> {
let (tx, rx) = oneshot::channel();
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
request.params.text_document_position.clone(),
PromptForType::Completion,
tx,
)))?;
let prompt = rx.await?;
let (tx, rx) = oneshot::channel();
memory_backend_tx.send(memory_worker::WorkerRequest::FilterText(
FilterRequest::new(request.params.text_document_position.clone(), tx),
))?;
let filter_text = rx.await?;
let response = transformer_backend.do_completion(&prompt).await?;
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
),
response.insert_text.clone(),
);
let item = CompletionItem {
label: format!("ai - {}", response.insert_text),
filter_text: Some(filter_text),
text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
kind: Some(CompletionItemKind::TEXT),
..Default::default()
};
let completion_list = CompletionList {
is_incomplete: false,
items: vec![item],
};
let result = Some(CompletionResponse::List(completion_list));
let result = serde_json::to_value(result).unwrap();
Ok(Response {
id: request.id.clone(),
result: Some(result),
error: None,
})
}
async fn do_generate(
transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
request: &GenerateRequest,
) -> anyhow::Result<Response> {
let (tx, rx) = oneshot::channel();
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
request.params.text_document_position.clone(),
PromptForType::Completion,
tx,
)))?;
let prompt = rx.await?;
let response = transformer_backend.do_generate(&prompt).await?;
let result = GenerateResult {
generated_text: response.generated_text,
};
let result = serde_json::to_value(result).unwrap();
Ok(Response {
id: request.id.clone(),
result: Some(result),
error: None,
})
}


@@ -20,15 +20,19 @@ pub fn tokens_to_estimated_characters(tokens: usize) -> usize {
tokens * 4
}
pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {
pub fn format_chat_messages(messages: &[ChatMessage], prompt: &Prompt) -> Vec<ChatMessage> {
messages
.iter()
.map(|m| ChatMessage {
role: m.role.to_owned(),
content: m
.content
.replace("{context}", &prompt.context)
.replace("{code}", &prompt.code),
.replace("{CONTEXT}", &prompt.context)
.replace("{CODE}", &prompt.code),
})
.collect()
}
pub fn format_context_code(context: &str, code: &str) -> String {
format!("{context}\n\n{code}")
}
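
A small usage sketch of the placeholder convention above: templates now use the uppercase `{CONTEXT}` and `{CODE}` markers, which get substituted per message. The `Msg` type here is an illustrative stand-in for the crate's `ChatMessage`:

#[derive(Clone, Debug)]
struct Msg {
    role: String,
    content: String,
}

// Same idea as format_chat_messages: substitute the prompt into each message.
fn fill(messages: &[Msg], context: &str, code: &str) -> Vec<Msg> {
    messages
        .iter()
        .map(|m| Msg {
            role: m.role.clone(),
            content: m
                .content
                .replace("{CONTEXT}", context)
                .replace("{CODE}", code),
        })
        .collect()
}

fn main() {
    let template = vec![Msg {
        role: "user".into(),
        content: "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}".into(),
    }];
    let filled = fill(&template, "def helper():\n    pass", "def main():\n    <CURSOR>");
    assert!(filled[0].content.contains("<CURSOR>"));
    println!("{}", filled[0].content);
}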


@@ -1,191 +0,0 @@
use lsp_server::{Connection, Message, RequestId, Response};
use lsp_types::{
CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
Position, Range, TextEdit,
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use tracing::instrument;
use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::GenerateStreamParams;
use crate::memory_backends::{MemoryBackend, PromptForType};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;
#[derive(Clone, Debug)]
pub struct CompletionRequest {
id: RequestId,
params: CompletionParams,
}
impl CompletionRequest {
pub fn new(id: RequestId, params: CompletionParams) -> Self {
Self { id, params }
}
}
#[derive(Clone, Debug)]
pub struct GenerateRequest {
id: RequestId,
params: GenerateParams,
}
impl GenerateRequest {
pub fn new(id: RequestId, params: GenerateParams) -> Self {
Self { id, params }
}
}
// The generate stream is not yet ready but we don't want to remove it
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct GenerateStreamRequest {
id: RequestId,
params: GenerateStreamParams,
}
impl GenerateStreamRequest {
pub fn new(id: RequestId, params: GenerateStreamParams) -> Self {
Self { id, params }
}
}
#[derive(Clone)]
pub enum WorkerRequest {
Completion(CompletionRequest),
Generate(GenerateRequest),
GenerateStream(GenerateStreamRequest),
}
pub struct DoCompletionResponse {
pub insert_text: String,
}
pub struct DoGenerateResponse {
pub generated_text: String,
}
pub struct DoGenerateStreamResponse {
pub generated_text: String,
}
pub struct Worker {
transformer_backend: Box<dyn TransformerBackend>,
memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
connection: Arc<Connection>,
}
impl Worker {
pub fn new(
transformer_backend: Box<dyn TransformerBackend>,
memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
connection: Arc<Connection>,
) -> Self {
Self {
transformer_backend,
memory_backend,
last_worker_request,
connection,
}
}
#[instrument(skip(self))]
fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
let prompt = self.memory_backend.lock().build_prompt(
&request.params.text_document_position,
PromptForType::Completion,
)?;
let filter_text = self
.memory_backend
.lock()
.get_filter_text(&request.params.text_document_position)?;
let response = self.transformer_backend.do_completion(&prompt)?;
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
),
response.insert_text.clone(),
);
let item = CompletionItem {
label: format!("ai - {}", response.insert_text),
filter_text: Some(filter_text),
text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
kind: Some(CompletionItemKind::TEXT),
..Default::default()
};
let completion_list = CompletionList {
is_incomplete: false,
items: vec![item],
};
let result = Some(CompletionResponse::List(completion_list));
let result = serde_json::to_value(&result).unwrap();
Ok(Response {
id: request.id.clone(),
result: Some(result),
error: None,
})
}
#[instrument(skip(self))]
fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
let prompt = self.memory_backend.lock().build_prompt(
&request.params.text_document_position,
PromptForType::Generate,
)?;
let response = self.transformer_backend.do_generate(&prompt)?;
let result = GenerateResult {
generated_text: response.generated_text,
};
let result = serde_json::to_value(&result).unwrap();
Ok(Response {
id: request.id.clone(),
result: Some(result),
error: None,
})
}
pub fn run(self) {
loop {
let option_worker_request: Option<WorkerRequest> = {
let mut completion_request = self.last_worker_request.lock();
std::mem::take(&mut *completion_request)
};
if let Some(request) = option_worker_request {
let response = match request {
WorkerRequest::Completion(request) => match self.do_completion(&request) {
Ok(r) => r,
Err(e) => Response {
id: request.id,
result: None,
error: Some(e.to_response_error(-32603)),
},
},
WorkerRequest::Generate(request) => match self.do_generate(&request) {
Ok(r) => r,
Err(e) => Response {
id: request.id,
result: None,
error: Some(e.to_response_error(-32603)),
},
},
WorkerRequest::GenerateStream(_) => panic!("Streaming is not supported yet"),
};
self.connection
.sender
.send(Message::Response(response))
.expect("Error sending message");
}
thread::sleep(std::time::Duration::from_millis(5));
}
}
}