Getting closer

Silas Marvin
2024-06-19 12:14:56 -07:00
parent 3e8c99b237
commit 9166aaf4b6
7 changed files with 186 additions and 35 deletions

View File

@@ -156,6 +156,8 @@ pub struct PostgresML {
     pub crawl: Option<Crawl>,
     #[serde(default)]
     pub splitter: ValidSplitter,
+    pub embedding_model: Option<String>,
+    pub embedding_model_parameters: Option<Value>,
 }
 
 #[derive(Clone, Debug, Deserialize, Default)]
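
For reference, a minimal sketch of how these two new fields might be set from an editor's lsp-ai configuration, assuming the usual `postgresml` memory section; the model name and parameters below are illustrative values, not defaults taken from this commit:

{
  "memory": {
    "postgresml": {
      "database_url": "postgres://user:password@localhost:5432/lsp_ai",
      "embedding_model": "intfloat/e5-small-v2",
      "embedding_model_parameters": { "prompt": "passage: " }
    }
  }
}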

View File

@@ -1,6 +1,6 @@
 use anyhow::Context;
 use indexmap::IndexSet;
-use lsp_types::TextDocumentPositionParams;
+use lsp_types::{Position, TextDocumentPositionParams};
 use parking_lot::Mutex;
 use ropey::Rope;
 use serde_json::Value;
@@ -154,6 +154,7 @@ impl FileStore {
         &self,
         position: &TextDocumentPositionParams,
         characters: usize,
+        pull_from_multiple_files: bool,
     ) -> anyhow::Result<(Rope, usize)> {
         // Get the rope and set our initial cursor index
         let current_document_uri = position.text_document.uri.to_string();
@@ -174,7 +175,7 @@ impl FileStore {
             .filter(|f| **f != current_document_uri)
         {
             let needed = characters.saturating_sub(rope.len_chars() + 1);
-            if needed == 0 {
+            if needed == 0 || !pull_from_multiple_files {
                 break;
             }
             let file_map = self.file_map.lock();
@@ -220,9 +221,13 @@ impl FileStore {
         position: &TextDocumentPositionParams,
         prompt_type: PromptType,
         params: MemoryRunParams,
+        pull_from_multiple_files: bool,
     ) -> anyhow::Result<Prompt> {
-        let (mut rope, cursor_index) =
-            self.get_rope_for_position(position, params.max_context_length)?;
+        let (mut rope, cursor_index) = self.get_rope_for_position(
+            position,
+            params.max_context_length,
+            pull_from_multiple_files,
+        )?;
 
         Ok(match prompt_type {
             PromptType::ContextAndCode => {
@@ -277,6 +282,20 @@ impl FileStore {
     pub fn contains_file(&self, uri: &str) -> bool {
         self.file_map.lock().contains_key(uri)
     }
+
+    pub fn position_to_byte(&self, position: &TextDocumentPositionParams) -> anyhow::Result<usize> {
+        let file_map = self.file_map.lock();
+        let uri = position.text_document.uri.to_string();
+        let file = file_map
+            .get(&uri)
+            .with_context(|| format!("trying to get file that does not exist {uri}"))?;
+        let line_char_index = file
+            .rope
+            .try_line_to_char(position.position.line as usize)?;
+        Ok(file
+            .rope
+            .try_char_to_byte(line_char_index + position.position.character as usize)?)
+    }
 }
 
 #[async_trait::async_trait]
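
As an aside, the new `position_to_byte` helper goes LSP (line, character) -> char index -> byte offset via ropey. A minimal standalone sketch of that arithmetic, with a hypothetical document and cursor position rather than anything from the codebase:

use ropey::Rope;

fn main() -> Result<(), ropey::Error> {
    let rope = Rope::from_str("fn main() {\n    println!(\"hi\");\n}\n");
    let (line, character) = (1, 4); // hypothetical LSP cursor position
    // line -> char index of the line start, then add the column, then char -> byte
    let line_char = rope.try_line_to_char(line)?;
    let byte = rope.try_char_to_byte(line_char + character)?;
    assert_eq!(byte, 16); // byte offset of `println!` on line 1
    Ok(())
}
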
@@ -307,7 +326,7 @@ impl MemoryBackend for FileStore {
         params: &Value,
     ) -> anyhow::Result<Prompt> {
         let params: MemoryRunParams = params.try_into()?;
-        self.build_code(position, prompt_type, params)
+        self.build_code(position, prompt_type, params, true)
     }
 
     #[instrument(skip(self))]

View File

@@ -29,11 +29,30 @@ use super::{
 const RESYNC_MAX_FILE_SIZE: u64 = 10_000_000;
 
-fn chunk_to_document(uri: &str, chunk: Chunk) -> Value {
+fn format_chunk_chunk(uri: &str, chunk: &Chunk, root_uri: Option<&str>) -> String {
+    let path = match root_uri {
+        Some(root_uri) => {
+            if uri.starts_with(root_uri) {
+                &uri[root_uri.chars().count()..]
+            } else {
+                uri
+            }
+        }
+        None => uri,
+    };
+    format!(
+        r#"--{path}--
+{}
+"#,
+        chunk.text
+    )
+}
+
+fn chunk_to_document(uri: &str, chunk: Chunk, root_uri: Option<&str>) -> Value {
     json!({
         "id": chunk_to_id(uri, &chunk),
         "uri": uri,
-        "text": chunk.text,
+        "text": format_chunk_chunk(uri, &chunk, root_uri),
         "range": chunk.range
     })
 }
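
To make the new formatting concrete, an upserted document now looks roughly like this, with illustrative values (the id and byte range are placeholders, and the assumed root_uri file:///home/user/project has been stripped from the path header):

{
  "id": "...",
  "uri": "file:///home/user/project/src/main.rs",
  "text": "--/src/main.rs--\nfn main() {}\n",
  "range": { "start": 0, "end": 12 }
}
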
@@ -43,6 +62,7 @@ async fn split_and_upsert_file(
     collection: &mut Collection,
     file_store: Arc<FileStore>,
     splitter: Arc<Box<dyn Splitter + Send + Sync>>,
+    root_uri: Option<&str>,
 ) -> anyhow::Result<()> {
     // We need to make sure we don't hold the file_store lock while performing a network call
     let chunks = {
@@ -55,7 +75,7 @@ async fn split_and_upsert_file(
     let chunks = chunks.with_context(|| format!("file not found for splitting: {uri}"))?;
     let documents = chunks
         .into_iter()
-        .map(|chunk| chunk_to_document(uri, chunk).into())
+        .map(|chunk| chunk_to_document(uri, chunk, root_uri).into())
         .collect();
     collection
         .upsert_documents(documents, None)
@@ -65,7 +85,7 @@ async fn split_and_upsert_file(
 #[derive(Clone)]
 pub struct PostgresML {
-    _config: Config,
+    config: Config,
     file_store: Arc<FileStore>,
     collection: Collection,
     pipeline: Pipeline,
@@ -100,21 +120,19 @@ impl PostgresML {
std::env::var("PGML_DATABASE_URL").context("please provide either the `database_url` in the `postgresml` config, or set the `PGML_DATABASE_URL` environment variable")? std::env::var("PGML_DATABASE_URL").context("please provide either the `database_url` in the `postgresml` config, or set the `PGML_DATABASE_URL` environment variable")?
}; };
let collection_name = match configuration.client_params.root_uri.clone() { // Build our pipeline schema
Some(root_uri) => format!("{:x}", md5::compute(root_uri.as_bytes())), let pipeline = match postgresml_config.embedding_model {
None => { Some(embedding_model) => {
warn!("no root_uri provided in server configuration - generating random string for collection name"); json!({
rand::thread_rng() "text": {
.sample_iter(&Alphanumeric) "semantic_search": {
.take(21) "model": embedding_model,
.map(char::from) "parameters": postgresml_config.embedding_model_parameters
.collect() }
}
})
} }
}; None => {
let mut collection = Collection::new(&collection_name, Some(database_url))?;
let mut pipeline = Pipeline::new(
"v1",
Some(
json!({ json!({
"text": { "text": {
"semantic_search": { "semantic_search": {
@@ -125,16 +143,36 @@ impl PostgresML {
                         }
                     }
                 })
-                .into(),
+            }
+        };
+
+        // When building the collection name we include the Pipeline schema
+        // If the user changes the Pipeline schema, it will take affect without them having to delete the old files
+        let collection_name = match configuration.client_params.root_uri.clone() {
+            Some(root_uri) => format!(
+                "{:x}",
+                md5::compute(
+                    format!("{root_uri}_{}", serde_json::to_string(&pipeline)?).as_bytes()
+                )
             ),
-        )?;
+            None => {
+                warn!("no root_uri provided in server configuration - generating random string for collection name");
+                rand::thread_rng()
+                    .sample_iter(&Alphanumeric)
+                    .take(21)
+                    .map(char::from)
+                    .collect()
+            }
+        };
+        let mut collection = Collection::new(&collection_name, Some(database_url))?;
+        let mut pipeline = Pipeline::new("v1", Some(pipeline.into()))?;
 
         // Add the Pipeline to the Collection
         TOKIO_RUNTIME.block_on(async {
             collection
                 .add_pipeline(&mut pipeline)
                 .await
-                .context("PGML - Error adding pipeline to collection")
+                .context("PGML - error adding pipeline to collection")
         })?;
 
         // Setup up a debouncer for changed text documents
@@ -142,6 +180,7 @@ impl PostgresML {
         let mut task_collection = collection.clone();
         let task_file_store = file_store.clone();
         let task_splitter = splitter.clone();
+        let task_root_uri = configuration.client_params.root_uri.clone();
         TOKIO_RUNTIME.spawn(async move {
             let duration = Duration::from_millis(500);
             let mut file_uris = Vec::new();
@@ -218,7 +257,9 @@ impl PostgresML {
                     .map(|(chunks, uri)| {
                         chunks
                             .into_iter()
-                            .map(|chunk| chunk_to_document(&uri, chunk))
+                            .map(|chunk| {
+                                chunk_to_document(&uri, chunk, task_root_uri.as_deref())
+                            })
                             .collect::<Vec<Value>>()
                     })
                     .flatten()
@@ -227,7 +268,7 @@ impl PostgresML {
                 if let Err(e) = task_collection
                     .upsert_documents(documents, None)
                     .await
-                    .context("PGML - Error upserting changed files")
+                    .context("PGML - error upserting changed files")
                 {
                     error!("{e:?}");
                     continue;
@@ -239,7 +280,7 @@ impl PostgresML {
         });
 
         let s = Self {
-            _config: configuration,
+            config: configuration,
             file_store,
             collection,
             pipeline,
@@ -317,7 +358,14 @@ impl PostgresML {
                     .splitter
                     .split_file_contents(&uri, &contents)
                     .into_iter()
-                    .map(|chunk| chunk_to_document(&uri, chunk).into())
+                    .map(|chunk| {
+                        chunk_to_document(
+                            &uri,
+                            chunk,
+                            self.config.client_params.root_uri.as_deref(),
+                        )
+                        .into()
+                    })
                     .collect();
                 chunks_to_upsert.extend(chunks);
                 // If we have over 10 mega bytes of chunks do the upsert
@@ -326,10 +374,18 @@ impl PostgresML {
                         .upsert_documents(chunks_to_upsert, None)
                         .await
                         .context("PGML - error upserting documents during resync")?;
+                    chunks_to_upsert = vec![];
+                    current_chunks_bytes = 0;
                 }
-                chunks_to_upsert = vec![];
             }
         }
+        // Upsert any remaining chunks
+        if chunks_to_upsert.len() > 0 {
+            collection
+                .upsert_documents(chunks_to_upsert, None)
+                .await
+                .context("PGML - error upserting documents during resync")?;
+        }
         // Delete documents
         if !documents_to_delete.is_empty() {
             collection
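
The resync path above now follows a flush-then-drain batching pattern: accumulate documents, flush once the running byte count passes the limit, and flush whatever is left after the loop. A self-contained sketch of that pattern (the `flush` callback and function name are illustrative; the real code calls `collection.upsert_documents`):

fn upsert_in_batches<F>(documents: Vec<String>, mut flush: F) -> anyhow::Result<()>
where
    F: FnMut(Vec<String>) -> anyhow::Result<()>,
{
    // Mirrors the roughly 10 MB threshold mentioned in the comments above
    const MAX_BATCH_BYTES: usize = 10_000_000;
    let mut batch = Vec::new();
    let mut batch_bytes = 0;
    for doc in documents {
        batch_bytes += doc.len();
        batch.push(doc);
        if batch_bytes > MAX_BATCH_BYTES {
            // Flush the full batch and start a new one
            flush(std::mem::take(&mut batch))?;
            batch_bytes = 0;
        }
    }
    // Flush whatever remains after the loop
    if !batch.is_empty() {
        flush(batch)?;
    }
    Ok(())
}
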
@@ -382,7 +438,14 @@ impl PostgresML {
                     .splitter
                     .split_file_contents(&uri, &contents)
                     .into_iter()
-                    .map(|chunk| chunk_to_document(&uri, chunk).into())
+                    .map(|chunk| {
+                        chunk_to_document(
+                            &uri,
+                            chunk,
+                            self.config.client_params.root_uri.as_deref(),
+                        )
+                        .into()
+                    })
                     .collect();
                 documents.extend(chunks);
                 // If we have over 10 mega bytes of data do the upsert
@@ -440,17 +503,28 @@ impl MemoryBackend for PostgresML {
     ) -> anyhow::Result<Prompt> {
         let params: MemoryRunParams = params.try_into()?;
 
+        // TOOD: FIGURE THIS OUT
+        // let prompt_size = params.max_context_length
+
         // Build the query
         let query = self
             .file_store
             .get_characters_around_position(position, 512)?;
 
-        // Get the code around the Cursor
+        // Build the prompt
         let mut file_store_params = params.clone();
         file_store_params.max_context_length = 512;
         let code = self
             .file_store
-            .build_code(position, prompt_type, file_store_params)?;
+            .build_code(position, prompt_type, file_store_params, false)?;
+
+        // Get the byte of the cursor
+        let cursor_byte = self.file_store.position_to_byte(position)?;
+        eprintln!(
+            "CURSOR BYTE: {} IN DOCUMENT: {}",
+            cursor_byte,
+            position.text_document.uri.to_string()
+        );
 
         // Get the context
         let limit = params.max_context_length / 512;
@@ -467,6 +541,29 @@ impl MemoryBackend for PostgresML {
                         }
                     }
                 },
+                "filter": {
+                    "$or": [
+                        {
+                            "uri": {
+                                "$ne": position.text_document.uri.to_string()
+                            }
+                        },
+                        {
+                            "range": {
+                                "start": {
+                                    "$gt": cursor_byte
+                                },
+                            },
+                        },
+                        {
+                            "range": {
+                                "end": {
+                                    "$lt": cursor_byte
+                                },
+                            }
+                        }
+                    ]
+                }
             },
             "limit": limit
         })
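
Read as a predicate, the new filter keeps a chunk only if it comes from a different file, or if it lies entirely before or after the cursor in the current file, so the chunk the cursor sits in is not echoed back as extra context. A plain-Rust sketch of the same condition (names are illustrative, not part of the codebase):

fn is_eligible_context(
    chunk_uri: &str,
    chunk_start: usize,
    chunk_end: usize,
    current_uri: &str,
    cursor_byte: usize,
) -> bool {
    // From another file, or strictly before/after the cursor byte in this file
    chunk_uri != current_uri || chunk_start > cursor_byte || chunk_end < cursor_byte
}
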
@@ -485,6 +582,8 @@ impl MemoryBackend for PostgresML {
             .collect::<anyhow::Result<Vec<String>>>()?
             .join("\n\n");
 
+        eprintln!("THE CONTEXT:\n\n{context}\n\n");
+
         let chars = tokens_to_estimated_characters(params.max_context_length.saturating_sub(512));
         let context = &context[..chars.min(context.len())];
@@ -512,9 +611,17 @@ impl MemoryBackend for PostgresML {
         let mut collection = self.collection.clone();
         let file_store = self.file_store.clone();
         let splitter = self.splitter.clone();
+        let root_uri = self.config.client_params.root_uri.clone();
         TOKIO_RUNTIME.spawn(async move {
             let uri = params.text_document.uri.to_string();
-            if let Err(e) = split_and_upsert_file(&uri, &mut collection, file_store, splitter).await
+            if let Err(e) = split_and_upsert_file(
+                &uri,
+                &mut collection,
+                file_store,
+                splitter,
+                root_uri.as_deref(),
+            )
+            .await
             {
                 error!("{e:?}")
             }
@@ -544,6 +651,7 @@ impl MemoryBackend for PostgresML {
         let mut collection = self.collection.clone();
         let file_store = self.file_store.clone();
         let splitter = self.splitter.clone();
+        let root_uri = self.config.client_params.root_uri.clone();
         TOKIO_RUNTIME.spawn(async move {
             for file in params.files {
                 if let Err(e) = collection
@@ -564,6 +672,7 @@ impl MemoryBackend for PostgresML {
                         &mut collection,
                         file_store.clone(),
                         splitter.clone(),
+                        root_uri.as_deref(),
                     )
                     .await
                     {

View File

@@ -39,6 +39,8 @@ pub trait Splitter {
     fn does_use_tree_sitter(&self) -> bool {
         false
     }
+
+    fn chunk_size(&self) -> usize;
 }
 
 impl TryFrom<ValidSplitter> for Box<dyn Splitter + Send + Sync> {

View File

@@ -3,18 +3,21 @@ use crate::{config, memory_backends::file_store::File};
 use super::{ByteRange, Chunk, Splitter};
 
 pub struct TextSplitter {
+    chunk_size: usize,
     splitter: text_splitter::TextSplitter<text_splitter::Characters>,
 }
 
 impl TextSplitter {
     pub fn new(config: config::TextSplitter) -> Self {
         Self {
+            chunk_size: config.chunk_size,
             splitter: text_splitter::TextSplitter::new(config.chunk_size),
         }
     }
 
     pub fn new_with_chunk_size(chunk_size: usize) -> Self {
         Self {
+            chunk_size,
             splitter: text_splitter::TextSplitter::new(chunk_size),
         }
     }
@@ -37,4 +40,8 @@ impl Splitter for TextSplitter {
                 acc
             })
     }
+
+    fn chunk_size(&self) -> usize {
+        self.chunk_size
+    }
 }

View File

@@ -7,6 +7,7 @@ use crate::{config, memory_backends::file_store::File, utils::parse_tree};
 use super::{text_splitter::TextSplitter, ByteRange, Chunk, Splitter};
 
 pub struct TreeSitter {
+    chunk_size: usize,
     splitter: TreeSitterCodeSplitter,
     text_splitter: TextSplitter,
 }
@@ -15,6 +16,7 @@ impl TreeSitter {
     pub fn new(config: config::TreeSitter) -> anyhow::Result<Self> {
         let text_splitter = TextSplitter::new_with_chunk_size(config.chunk_size);
         Ok(Self {
+            chunk_size: config.chunk_size,
             splitter: TreeSitterCodeSplitter::new(config.chunk_size, config.chunk_overlap)?,
             text_splitter,
         })
@@ -75,4 +77,8 @@ impl Splitter for TreeSitter {
     fn does_use_tree_sitter(&self) -> bool {
         true
     }
+
+    fn chunk_size(&self) -> usize {
+        self.chunk_size
+    }
 }

View File

@@ -156,6 +156,12 @@ impl OpenAI {
         messages: Vec<ChatMessage>,
         params: OpenAIRunParams,
     ) -> anyhow::Result<String> {
+        eprintln!("\n\n\n\n");
+        for message in &messages {
+            eprintln!("{}:\n{}\n", message.role.to_string(), message.content);
+        }
+        eprintln!("\n\n\n\n");
+
         let client = reqwest::Client::new();
         let token = self.get_token()?;
         let res: OpenAIChatResponse = client