mirror of https://github.com/SilasMarvin/lsp-ai.git
Added PostgresML memory backend
.gitmodules (vendored, new file, +3)
@@ -0,0 +1,3 @@
+[submodule "submodules/postgresml"]
+	path = submodules/postgresml
+	url = https://github.com/postgresml/postgresml
Cargo.lock (generated, 1572 changes)
File diff suppressed because it is too large
@@ -26,6 +26,9 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 tracing = "0.1.40"
 xxhash-rust = { version = "0.8.5", features = ["xxh3"] }
 reqwest = { version = "0.11.25", features = ["blocking", "json"] }
 ignore = "0.4.22"
+pgml = { path = "submodules/postgresml/pgml-sdks/pgml" }
+tokio = { version = "1.36.0", features = ["rt-multi-thread"] }
 
 [features]
 default = []
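Note: pgml is a path dependency into the newly added submodule, so building this revision requires the submodule to be checked out first (e.g. git submodule update --init --recursive).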
@@ -13,7 +13,7 @@ pub type Kwargs = HashMap<String, Value>;
 
 pub enum ValidMemoryBackend {
     FileStore,
-    PostgresML,
+    PostgresML(PostgresML),
 }
 
 pub enum ValidTransformerBackend {
@@ -57,15 +57,22 @@ impl Default for MaxTokens {
     }
 }
 
+#[derive(Clone, Debug, Deserialize)]
+pub struct PostgresML {
+    pub database_url: Option<String>,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 struct ValidMemoryConfiguration {
     file_store: Option<Value>,
+    postgresml: Option<PostgresML>,
 }
 
 impl Default for ValidMemoryConfiguration {
     fn default() -> Self {
         Self {
             file_store: Some(json!({})),
+            postgresml: None,
         }
     }
 }
@@ -188,13 +195,22 @@ struct ValidConfiguration {
     transformer: ValidTransformerConfiguration,
 }
 
+#[derive(Clone, Debug, Deserialize, Default)]
+pub struct ValidClientParams {
+    #[serde(alias = "rootURI")]
+    root_uri: Option<String>,
+    workspace_folders: Option<Vec<String>>,
+}
+
 #[derive(Clone, Debug)]
 pub struct Configuration {
     valid_config: ValidConfiguration,
+    client_params: ValidClientParams,
 }
 
 impl Configuration {
     pub fn new(mut args: Value) -> Result<Self> {
+        eprintln!("\n\n{}\n\n", args.to_string());
         let configuration_args = args
             .as_object_mut()
             .context("Server configuration must be a JSON object")?
@@ -203,14 +219,18 @@ impl Configuration {
             Some(configuration_args) => serde_json::from_value(configuration_args)?,
             None => ValidConfiguration::default(),
         };
+        let client_params: ValidClientParams = serde_json::from_value(args)?;
         Ok(Self {
             valid_config: valid_args,
+            client_params,
         })
     }
 
     pub fn get_memory_backend(&self) -> Result<ValidMemoryBackend> {
         if self.valid_config.memory.file_store.is_some() {
             Ok(ValidMemoryBackend::FileStore)
+        } else if let Some(postgresml) = &self.valid_config.memory.postgresml {
+            Ok(ValidMemoryBackend::PostgresML(postgresml.to_owned()))
         } else {
             anyhow::bail!("Invalid memory configuration")
         }
@@ -230,7 +250,7 @@ impl Configuration {
     // Helpers for the Memory Backend /////
     ///////////////////////////////////////
 
-    pub fn get_maximum_context_length(&self) -> Result<usize> {
+    pub fn get_max_context_length(&self) -> Result<usize> {
         if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
             Ok(model_gguf
                 .kwargs
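For reference, a minimal sketch (not part of the diff) of the memory section these structs deserialize. The connection string below is a placeholder, the exact nesting of the server's initialization options is not shown in this hunk, and the new backend below falls back to the PGML_DATABASE_URL environment variable when database_url is omitted.

// Sketch only: a `memory` section that ValidMemoryConfiguration above would accept.
// With `file_store` absent and `postgresml` present, get_memory_backend resolves to
// ValidMemoryBackend::PostgresML. The URL is a placeholder, not a real database.
use serde_json::json;

fn main() {
    let memory = json!({
        "postgresml": {
            "database_url": "postgres://user:pass@localhost:5432/lsp_ai"
        }
    });
    println!("{}", memory);
}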
@@ -4,7 +4,7 @@ use ropey::Rope;
 use std::collections::HashMap;
 use tracing::instrument;
 
-use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};
+use crate::{configuration::Configuration, utils::tokens_to_estimated_characters};
 
 use super::{MemoryBackend, Prompt, PromptForType};
@@ -20,28 +20,35 @@ impl FileStore {
             file_map: HashMap::new(),
         }
     }
-}
 
-impl MemoryBackend for FileStore {
-    #[instrument(skip(self))]
-    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
+    pub fn get_characters_around_position(
+        &self,
+        position: &TextDocumentPositionParams,
+        characters: usize,
+    ) -> anyhow::Result<String> {
         let rope = self
             .file_map
             .get(position.text_document.uri.as_str())
             .context("Error file not found")?
             .clone();
-        Ok(rope
-            .get_line(position.position.line as usize)
-            .context("Error getting filter_text")?
-            .to_string())
+        let cursor_index = rope.line_to_char(position.position.line as usize)
+            + position.position.character as usize;
+        let start = cursor_index.checked_sub(characters / 2).unwrap_or(0);
+        let end = rope
+            .len_chars()
+            .min(cursor_index + (characters - (cursor_index - start)));
+        let rope_slice = rope
+            .get_slice(start..end)
+            .context("Error getting rope slice")?;
+        Ok(rope_slice.to_string())
     }
 
     #[instrument(skip(self))]
-    fn build_prompt(
+    pub fn build_code(
         &self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
-    ) -> anyhow::Result<Prompt> {
+        max_context_length: usize,
+    ) -> anyhow::Result<String> {
         let mut rope = self
             .file_map
             .get(position.text_document.uri.as_str())
@@ -66,13 +73,11 @@ impl MemoryBackend for FileStore {
 
         // We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
         // and the user has not enabled chat
-        let code = match (is_chat_enabled, self.configuration.get_fim()?) {
+        Ok(match (is_chat_enabled, self.configuration.get_fim()?) {
             r @ (true, _) | r @ (false, Some(_))
                 if is_chat_enabled || rope.len_chars() != cursor_index =>
             {
-                let max_length = characters_to_estimated_tokens(
-                    self.configuration.get_maximum_context_length()?,
-                );
+                let max_length = tokens_to_estimated_characters(max_context_length);
                 let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
                 let end = rope
                     .len_chars()
@@ -103,16 +108,42 @@ impl MemoryBackend for FileStore {
             }
             _ => {
                 let start = cursor_index
-                    .checked_sub(characters_to_estimated_tokens(
-                        self.configuration.get_maximum_context_length()?,
-                    ))
+                    .checked_sub(tokens_to_estimated_characters(max_context_length))
                     .unwrap_or(0);
                 let rope_slice = rope
                     .get_slice(start..cursor_index)
                     .context("Error getting rope slice")?;
                 rope_slice.to_string()
             }
-        };
+        })
     }
+}
+
+impl MemoryBackend for FileStore {
+    #[instrument(skip(self))]
+    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
+        let rope = self
+            .file_map
+            .get(position.text_document.uri.as_str())
+            .context("Error file not found")?
+            .clone();
+        Ok(rope
+            .get_line(position.position.line as usize)
+            .context("Error getting filter_text")?
+            .to_string())
+    }
+
+    #[instrument(skip(self))]
+    fn build_prompt(
+        &mut self,
+        position: &TextDocumentPositionParams,
+        prompt_for_type: PromptForType,
+    ) -> anyhow::Result<Prompt> {
+        let code = self.build_code(
+            position,
+            prompt_for_type,
+            self.configuration.get_max_context_length()?,
+        )?;
+        Ok(Prompt::new("".to_string(), code))
+    }
 
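A standalone sketch (not part of the diff) of the character-window arithmetic in get_characters_around_position above, written against plain indices instead of a ropey::Rope; the numbers are illustrative only.

// Mirrors the start/end computation used by get_characters_around_position.
// When the cursor is too close to the start of the file to spend half the budget
// backwards, the leftover budget is spent forwards instead.
fn window(cursor_index: usize, len_chars: usize, characters: usize) -> (usize, usize) {
    let start = cursor_index.checked_sub(characters / 2).unwrap_or(0);
    let end = len_chars.min(cursor_index + (characters - (cursor_index - start)));
    (start, end)
}

fn main() {
    // Cursor near the start of a 1000-character file: the window shifts right.
    assert_eq!(window(10, 1000, 512), (0, 512));
    // Cursor mid-file: the window is centered on the cursor.
    assert_eq!(window(600, 1000, 512), (344, 856));
}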
@@ -6,6 +6,7 @@ use lsp_types::{
 use crate::configuration::{Configuration, ValidMemoryBackend};
 
 pub mod file_store;
+mod postgresml;
 
 #[derive(Debug)]
 pub struct Prompt {
@@ -33,7 +34,7 @@ pub trait MemoryBackend {
     fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
     fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
     fn build_prompt(
-        &self,
+        &mut self,
         position: &TextDocumentPositionParams,
         prompt_for_type: PromptForType,
     ) -> anyhow::Result<Prompt>;
@@ -48,7 +49,9 @@ impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
             ValidMemoryBackend::FileStore => {
                 Ok(Box::new(file_store::FileStore::new(configuration)))
             }
-            _ => unimplemented!(),
+            ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new(
+                postgresml::PostgresML::new(postgresml_config, configuration)?,
+            )),
         }
     }
 }
src/memory_backends/postgresml/mod.rs (new file, +221)
@@ -0,0 +1,221 @@
+use std::path::Path;
+
+use anyhow::Context;
+use lsp_types::TextDocumentPositionParams;
+use pgml::{Collection, Pipeline};
+use serde_json::json;
+use tokio::runtime::Runtime;
+use tracing::instrument;
+
+use crate::{
+    configuration::{self, Configuration},
+    utils::tokens_to_estimated_characters,
+};
+
+use super::{file_store::FileStore, MemoryBackend, Prompt, PromptForType};
+
+pub struct PostgresML {
+    configuration: Configuration,
+    file_store: FileStore,
+    collection: Collection,
+    pipeline: Pipeline,
+    runtime: Runtime,
+}
+
+impl PostgresML {
+    pub fn new(
+        postgresml_config: configuration::PostgresML,
+        configuration: Configuration,
+    ) -> anyhow::Result<Self> {
+        let file_store = FileStore::new(configuration.clone());
+        let database_url = if let Some(database_url) = postgresml_config.database_url {
+            database_url
+        } else {
+            std::env::var("PGML_DATABASE_URL")?
+        };
+        // TODO: Think about the naming of the collection;
+        // maybe filter on metadata instead
+        let collection = Collection::new("test-lsp-ai", Some(database_url))?;
+        // TODO: Review the pipeline
+        let pipeline = Pipeline::new(
+            "v1",
+            Some(
+                json!({
+                    "text": {
+                        "splitter": {
+                            "model": "recursive_character",
+                            "parameters": {
+                                "chunk_size": 512,
+                                "chunk_overlap": 40
+                            }
+                        },
+                        "semantic_search": {
+                            "model": "intfloat/e5-small",
+                        }
+                    }
+                })
+                .into(),
+            ),
+        )?;
+        // Create our own runtime
+        let runtime = tokio::runtime::Builder::new_multi_thread()
+            .worker_threads(2)
+            .enable_all()
+            .build()?;
+        // Add the pipeline to the collection
+        let mut task_collection = collection.clone();
+        let mut task_pipeline = pipeline.clone();
+        runtime.spawn(async move {
+            task_collection
+                .add_pipeline(&mut task_pipeline)
+                .await
+                .expect("PGML - Error adding pipeline to collection");
+        });
+        // Need to crawl the root path and/or workspace folders,
+        // or set some kind of did-crawl flag for it
+        Ok(Self {
+            configuration,
+            file_store,
+            collection,
+            pipeline,
+            runtime,
+        })
+    }
+}
+
+impl MemoryBackend for PostgresML {
+    #[instrument(skip(self))]
+    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
+        self.file_store.get_filter_text(position)
+    }
+
+    #[instrument(skip(self))]
+    fn build_prompt(
+        &mut self,
+        position: &TextDocumentPositionParams,
+        prompt_for_type: PromptForType,
+    ) -> anyhow::Result<Prompt> {
+        // This is blocking, but that is ok as we only query for it from the worker
+        // when we are actually doing a transform
+        let query = self
+            .file_store
+            .get_characters_around_position(position, 512)?;
+        let res = self.runtime.block_on(
+            self.collection.vector_search(
+                json!({
+                    "query": {
+                        "fields": {
+                            "text": {
+                                "query": query
+                            }
+                        },
+                    },
+                    "limit": 5
+                })
+                .into(),
+                &mut self.pipeline,
+            ),
+        )?;
+        let context = res
+            .into_iter()
+            .map(|c| {
+                c["chunk"]
+                    .as_str()
+                    .map(|t| t.to_owned())
+                    .context("PGML - Error getting chunk from vector search")
+            })
+            .collect::<anyhow::Result<Vec<String>>>()?
+            .join("\n\n");
+        let code = self.file_store.build_code(position, prompt_for_type, 512)?;
+        let max_characters =
+            tokens_to_estimated_characters(self.configuration.get_max_context_length()?);
+        let context: String = context
+            .chars()
+            .take(max_characters - code.chars().count())
+            .collect();
+        eprintln!("CONTEXT: {}", context);
+        eprintln!("CODE: #########{}######", code);
+        Ok(Prompt::new(context, code))
+    }
+
+    #[instrument(skip(self))]
+    fn opened_text_document(
+        &mut self,
+        params: lsp_types::DidOpenTextDocumentParams,
+    ) -> anyhow::Result<()> {
+        let text = params.text_document.text.clone();
+        let path = params.text_document.uri.path().to_owned();
+        let mut task_collection = self.collection.clone();
+        self.runtime.spawn(async move {
+            task_collection
+                .upsert_documents(
+                    vec![json!({
+                        "id": path,
+                        "text": text
+                    })
+                    .into()],
+                    None,
+                )
+                .await
+                .expect("PGML - Error upserting documents");
+        });
+        self.file_store.opened_text_document(params)
+    }
+
+    #[instrument(skip(self))]
+    fn changed_text_document(
+        &mut self,
+        params: lsp_types::DidChangeTextDocumentParams,
+    ) -> anyhow::Result<()> {
+        let path = params.text_document.uri.path().to_owned();
+        let text = std::fs::read_to_string(&path)
+            .with_context(|| format!("Error reading path: {}", path))?;
+        let mut task_collection = self.collection.clone();
+        self.runtime.spawn(async move {
+            task_collection
+                .upsert_documents(
+                    vec![json!({
+                        "id": path,
+                        "text": text
+                    })
+                    .into()],
+                    None,
+                )
+                .await
+                .expect("PGML - Error upserting documents");
+        });
+        self.file_store.changed_text_document(params)
+    }
+
+    #[instrument(skip(self))]
+    fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
+        let mut task_collection = self.collection.clone();
+        let task_params = params.clone();
+        self.runtime.spawn(async move {
+            for file in task_params.files {
+                task_collection
+                    .delete_documents(
+                        json!({
+                            "id": file.old_uri
+                        })
+                        .into(),
+                    )
+                    .await
+                    .expect("PGML - Error deleting file");
+                let text =
+                    std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
+                task_collection
+                    .upsert_documents(
+                        vec![json!({
+                            "id": file.new_uri,
+                            "text": text
+                        })
+                        .into()],
+                        None,
+                    )
+                    .await
+                    .expect("PGML - Error upserting documents");
+            }
+        });
+        self.file_store.renamed_file(params)
+    }
+}
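A small sketch (not part of the diff) of how build_prompt above splits its budget: the code window comes from a fixed 512-token slice around the cursor, and the retrieved chunks fill whatever of the overall character budget remains. The saturating_sub is an assumption added here to make the underflow case explicit; the diff subtracts directly.

// Illustrates the budget arithmetic in PostgresML::build_prompt; values are made up.
fn remaining_context_chars(max_context_tokens: usize, code_chars: usize) -> usize {
    // tokens_to_estimated_characters: roughly 4 characters per token
    let max_characters = max_context_tokens * 4;
    // The diff computes `max_characters - code.chars().count()` directly; this sketch
    // saturates so it cannot underflow when the code alone exceeds the budget.
    max_characters.saturating_sub(code_chars)
}

fn main() {
    // A 2048-token budget with ~2000 characters of code leaves ~6192 characters
    // for the chunks returned by vector_search.
    assert_eq!(remaining_context_chars(2048, 2000), 6192);
}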
@@ -32,6 +32,7 @@ impl OpenAI {
     }
 
     fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
+        eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
         let client = reqwest::blocking::Client::new();
         let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
             std::env::var(env_var_name)?
@@ -66,15 +67,18 @@ impl OpenAI {
 impl TransformerBackend for OpenAI {
     #[instrument(skip(self))]
     fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
-        let insert_text =
-            self.get_completion(&prompt.code, self.configuration.max_tokens.completion)?;
+        eprintln!("--------------{:?}---------------", prompt);
+        let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
+        let insert_text = self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
         Ok(DoCompletionResponse { insert_text })
     }
 
     #[instrument(skip(self))]
     fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+        eprintln!("--------------{:?}---------------", prompt);
+        let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
         let generated_text =
-            self.get_completion(&prompt.code, self.configuration.max_tokens.completion)?;
+            self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
         Ok(DoGenerateResponse { generated_text })
     }
 
@@ -16,8 +16,8 @@ impl ToResponseError for anyhow::Error {
     }
 }
 
-pub fn characters_to_estimated_tokens(characters: usize) -> usize {
-    characters * 4
+pub fn tokens_to_estimated_characters(tokens: usize) -> usize {
+    tokens * 4
 }
 
 pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {
submodules/postgresml (new submodule)
Submodule submodules/postgresml added at 0842673804