Added PostgresML memory backend

This commit is contained in:
Silas Marvin
2024-03-10 15:29:47 -07:00
parent 2081c40759
commit 047215fc31
10 changed files with 1809 additions and 107 deletions

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "submodules/postgresml"]
path = submodules/postgresml
url = https://github.com/postgresml/postgresml

1572
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -26,6 +26,9 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
xxhash-rust = { version = "0.8.5", features = ["xxh3"] }
reqwest = { version = "0.11.25", features = ["blocking", "json"] }
ignore = "0.4.22"
pgml = { path = "submodules/postgresml/pgml-sdks/pgml" }
tokio = { version = "1.36.0", features = ["rt-multi-thread"] }
[features]
default = []

View File

@@ -13,7 +13,7 @@ pub type Kwargs = HashMap<String, Value>;
pub enum ValidMemoryBackend {
FileStore,
PostgresML,
PostgresML(PostgresML),
}
pub enum ValidTransformerBackend {
@@ -57,15 +57,22 @@ impl Default for MaxTokens {
}
}
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresML {
pub database_url: Option<String>,
}
#[derive(Clone, Debug, Deserialize)]
struct ValidMemoryConfiguration {
file_store: Option<Value>,
postgresml: Option<PostgresML>,
}
impl Default for ValidMemoryConfiguration {
fn default() -> Self {
Self {
file_store: Some(json!({})),
postgresml: None,
}
}
}
@@ -188,13 +195,22 @@ struct ValidConfiguration {
transformer: ValidTransformerConfiguration,
}
#[derive(Clone, Debug, Deserialize, Default)]
pub struct ValidClientParams {
#[serde(alias = "rootURI")]
root_uri: Option<String>,
workspace_folders: Option<Vec<String>>,
}
#[derive(Clone, Debug)]
pub struct Configuration {
valid_config: ValidConfiguration,
client_params: ValidClientParams,
}
impl Configuration {
pub fn new(mut args: Value) -> Result<Self> {
eprintln!("\n\n{}\n\n", args.to_string());
let configuration_args = args
.as_object_mut()
.context("Server configuration must be a JSON object")?
@@ -203,14 +219,18 @@ impl Configuration {
Some(configuration_args) => serde_json::from_value(configuration_args)?,
None => ValidConfiguration::default(),
};
let client_params: ValidClientParams = serde_json::from_value(args)?;
Ok(Self {
valid_config: valid_args,
client_params,
})
}
pub fn get_memory_backend(&self) -> Result<ValidMemoryBackend> {
if self.valid_config.memory.file_store.is_some() {
Ok(ValidMemoryBackend::FileStore)
} else if let Some(postgresml) = &self.valid_config.memory.postgresml {
Ok(ValidMemoryBackend::PostgresML(postgresml.to_owned()))
} else {
anyhow::bail!("Invalid memory configuration")
}
@@ -230,7 +250,7 @@ impl Configuration {
// Helpers for the Memory Backend /////
///////////////////////////////////////
pub fn get_maximum_context_length(&self) -> Result<usize> {
pub fn get_max_context_length(&self) -> Result<usize> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
Ok(model_gguf
.kwargs

View File

@@ -4,7 +4,7 @@ use ropey::Rope;
use std::collections::HashMap;
use tracing::instrument;
use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};
use crate::{configuration::Configuration, utils::tokens_to_estimated_characters};
use super::{MemoryBackend, Prompt, PromptForType};
@@ -20,28 +20,35 @@ impl FileStore {
file_map: HashMap::new(),
}
}
}
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
pub fn get_characters_around_position(
&self,
position: &TextDocumentPositionParams,
characters: usize,
) -> anyhow::Result<String> {
let rope = self
.file_map
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
Ok(rope
.get_line(position.position.line as usize)
.context("Error getting filter_text")?
.to_string())
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
let start = cursor_index.checked_sub(characters / 2).unwrap_or(0);
let end = rope
.len_chars()
.min(cursor_index + (characters - (cursor_index - start)));
let rope_slice = rope
.get_slice(start..end)
.context("Error getting rope slice")?;
Ok(rope_slice.to_string())
}
#[instrument(skip(self))]
fn build_prompt(
pub fn build_code(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
max_context_length: usize,
) -> anyhow::Result<String> {
let mut rope = self
.file_map
.get(position.text_document.uri.as_str())
@@ -66,13 +73,11 @@ impl MemoryBackend for FileStore {
// We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
// and the user has not enabled chat
let code = match (is_chat_enabled, self.configuration.get_fim()?) {
Ok(match (is_chat_enabled, self.configuration.get_fim()?) {
r @ (true, _) | r @ (false, Some(_))
if is_chat_enabled || rope.len_chars() != cursor_index =>
{
let max_length = characters_to_estimated_tokens(
self.configuration.get_maximum_context_length()?,
);
let max_length = tokens_to_estimated_characters(max_context_length);
let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
let end = rope
.len_chars()
@@ -103,16 +108,42 @@ impl MemoryBackend for FileStore {
}
_ => {
let start = cursor_index
.checked_sub(characters_to_estimated_tokens(
self.configuration.get_maximum_context_length()?,
))
.checked_sub(tokens_to_estimated_characters(max_context_length))
.unwrap_or(0);
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
rope_slice.to_string()
}
};
})
}
}
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
let rope = self
.file_map
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
Ok(rope
.get_line(position.position.line as usize)
.context("Error getting filter_text")?
.to_string())
}
#[instrument(skip(self))]
fn build_prompt(
&mut self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
let code = self.build_code(
position,
prompt_for_type,
self.configuration.get_max_context_length()?,
)?;
Ok(Prompt::new("".to_string(), code))
}

View File

@@ -6,6 +6,7 @@ use lsp_types::{
use crate::configuration::{Configuration, ValidMemoryBackend};
pub mod file_store;
mod postgresml;
#[derive(Debug)]
pub struct Prompt {
@@ -33,7 +34,7 @@ pub trait MemoryBackend {
fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
fn build_prompt(
&self,
&mut self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt>;
@@ -48,7 +49,9 @@ impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
ValidMemoryBackend::FileStore => {
Ok(Box::new(file_store::FileStore::new(configuration)))
}
_ => unimplemented!(),
ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new(
postgresml::PostgresML::new(postgresml_config, configuration)?,
)),
}
}
}

View File

@@ -0,0 +1,221 @@
use std::path::Path;
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use pgml::{Collection, Pipeline};
use serde_json::json;
use tokio::runtime::Runtime;
use tracing::instrument;
use crate::{
configuration::{self, Configuration},
utils::tokens_to_estimated_characters,
};
use super::{file_store::FileStore, MemoryBackend, Prompt, PromptForType};
/// Memory backend that indexes workspace documents into a PostgresML
/// collection and augments completion prompts with vector-search context.
/// Wraps a `FileStore` for all plain-text, position-based operations.
pub struct PostgresML {
    // Full server configuration; used to look up the max context length.
    configuration: Configuration,
    // In-memory rope store; filter-text and code-window queries delegate here.
    file_store: FileStore,
    // PostgresML collection the documents are upserted into.
    collection: Collection,
    // Splitting + embedding pipeline attached to the collection at startup.
    pipeline: Pipeline,
    // Dedicated multi-thread Tokio runtime driving the async pgml SDK calls.
    runtime: Runtime,
}
impl PostgresML {
pub fn new(
postgresml_config: configuration::PostgresML,
configuration: Configuration,
) -> anyhow::Result<Self> {
let file_store = FileStore::new(configuration.clone());
let database_url = if let Some(database_url) = postgresml_config.database_url {
database_url
} else {
std::env::var("PGML_DATABASE_URL")?
};
// TODO: Think on the naming of the collection
// Maybe filter on metadata or I'm not sure
let collection = Collection::new("test-lsp-ai", Some(database_url))?;
// TODO: Review the pipeline
let pipeline = Pipeline::new(
"v1",
Some(
json!({
"text": {
"splitter": {
"model": "recursive_character",
"parameters": {
"chunk_size": 512,
"chunk_overlap": 40
}
},
"semantic_search": {
"model": "intfloat/e5-small",
}
}
})
.into(),
),
)?;
// Create our own runtime
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()?;
// Add the collection to the pipeline
let mut task_collection = collection.clone();
let mut task_pipeline = pipeline.clone();
runtime.spawn(async move {
task_collection
.add_pipeline(&mut task_pipeline)
.await
.expect("PGML - Error adding pipeline to collection");
});
// Need to crawl the root path and or workspace folders
// Or set some kind of did crawl for it
Ok(Self {
configuration,
file_store,
collection,
pipeline,
runtime,
})
}
}
impl MemoryBackend for PostgresML {
    /// Delegates filter-text extraction to the in-memory file store.
    #[instrument(skip(self))]
    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
        self.file_store.get_filter_text(position)
    }

    /// Builds a prompt by vector-searching the collection for chunks similar
    /// to the text around the cursor, then pairing that retrieved context with
    /// the code window from the file store.
    #[instrument(skip(self))]
    fn build_prompt(
        &mut self,
        position: &TextDocumentPositionParams,
        prompt_for_type: PromptForType,
    ) -> anyhow::Result<Prompt> {
        // This is blocking, but this is ok as we only query for it from the
        // worker when we are actually doing a transform.
        let query = self
            .file_store
            .get_characters_around_position(position, 512)?;
        let res = self.runtime.block_on(
            self.collection.vector_search(
                json!({
                    "query": {
                        "fields": {
                            "text": {
                                "query": query
                            }
                        },
                    },
                    "limit": 5
                })
                .into(),
                &mut self.pipeline,
            ),
        )?;
        let context = res
            .into_iter()
            .map(|c| {
                c["chunk"]
                    .as_str()
                    .map(|t| t.to_owned())
                    .context("PGML - Error getting chunk from vector search")
            })
            .collect::<anyhow::Result<Vec<String>>>()?
            .join("\n\n");
        let code = self.file_store.build_code(position, prompt_for_type, 512)?;
        // Trim the retrieved context so context + code fits the estimated
        // character budget. saturating_sub guards against a panic when the
        // code window alone exceeds the budget.
        let max_characters =
            tokens_to_estimated_characters(self.configuration.get_max_context_length()?);
        let context: String = context
            .chars()
            .take(max_characters.saturating_sub(code.chars().count()))
            .collect();
        Ok(Prompt::new(context, code))
    }

    /// Records the opened document in the file store and upserts its full text
    /// into the collection in the background.
    #[instrument(skip(self))]
    fn opened_text_document(
        &mut self,
        params: lsp_types::DidOpenTextDocumentParams,
    ) -> anyhow::Result<()> {
        let text = params.text_document.text.clone();
        let path = params.text_document.uri.path().to_owned();
        let mut task_collection = self.collection.clone();
        // Index in the background; the LSP notification handler must not block.
        self.runtime.spawn(async move {
            task_collection
                .upsert_documents(
                    vec![json!({
                        "id": path,
                        "text": text
                    })
                    .into()],
                    None,
                )
                .await
                .expect("PGML - Error upserting document");
        });
        self.file_store.opened_text_document(params)
    }

    /// Re-indexes a changed document and forwards the change to the file store.
    #[instrument(skip(self))]
    fn changed_text_document(
        &mut self,
        params: lsp_types::DidChangeTextDocumentParams,
    ) -> anyhow::Result<()> {
        let path = params.text_document.uri.path().to_owned();
        // NOTE(review): this reads the file from disk, which may lag behind the
        // unsaved in-editor buffer carried by `params` — confirm intended.
        let text = std::fs::read_to_string(&path)
            .with_context(|| format!("Error reading path: {}", path))?;
        let mut task_collection = self.collection.clone();
        self.runtime.spawn(async move {
            task_collection
                .upsert_documents(
                    vec![json!({
                        "id": path,
                        "text": text
                    })
                    .into()],
                    None,
                )
                .await
                .expect("PGML - Error upserting document");
        });
        self.file_store.changed_text_document(params)
    }

    /// For each renamed file, deletes the document stored under the old URI and
    /// upserts the file contents under the new URI, then updates the file store.
    #[instrument(skip(self))]
    fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
        let mut task_collection = self.collection.clone();
        let task_params = params.clone();
        self.runtime.spawn(async move {
            for file in task_params.files {
                task_collection
                    .delete_documents(
                        json!({
                            "id": file.old_uri
                        })
                        .into(),
                    )
                    .await
                    .expect("PGML - Error deleting file");
                let text =
                    std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
                task_collection
                    .upsert_documents(
                        vec![json!({
                            "id": file.new_uri,
                            "text": text
                        })
                        .into()],
                        None,
                    )
                    .await
                    .expect("PGML - Error upserting document");
            }
        });
        self.file_store.renamed_file(params)
    }
}

View File

@@ -32,6 +32,7 @@ impl OpenAI {
}
fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
let client = reqwest::blocking::Client::new();
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
std::env::var(env_var_name)?
@@ -66,15 +67,18 @@ impl OpenAI {
impl TransformerBackend for OpenAI {
#[instrument(skip(self))]
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
let insert_text =
self.get_completion(&prompt.code, self.configuration.max_tokens.completion)?;
eprintln!("--------------{:?}---------------", prompt);
let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
let insert_text = self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
Ok(DoCompletionResponse { insert_text })
}
#[instrument(skip(self))]
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
eprintln!("--------------{:?}---------------", prompt);
let prompt = format!("{} \n\n {}", prompt.context, prompt.code);
let generated_text =
self.get_completion(&prompt.code, self.configuration.max_tokens.completion)?;
self.get_completion(&prompt, self.configuration.max_tokens.completion)?;
Ok(DoGenerateResponse { generated_text })
}

View File

@@ -16,8 +16,8 @@ impl ToResponseError for anyhow::Error {
}
}
pub fn characters_to_estimated_tokens(characters: usize) -> usize {
characters * 4
pub fn tokens_to_estimated_characters(tokens: usize) -> usize {
tokens * 4
}
pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {

1
submodules/postgresml Submodule

Submodule submodules/postgresml added at 0842673804