Added PostgresML memory backend

This commit is contained in:
Silas Marvin
2024-03-10 15:29:47 -07:00
parent 2081c40759
commit 047215fc31
10 changed files with 1809 additions and 107 deletions

View File

@@ -4,7 +4,7 @@ use ropey::Rope;
use std::collections::HashMap;
use tracing::instrument;
use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};
use crate::{configuration::Configuration, utils::tokens_to_estimated_characters};
use super::{MemoryBackend, Prompt, PromptForType};
@@ -20,28 +20,35 @@ impl FileStore {
file_map: HashMap::new(),
}
}
}
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
pub fn get_characters_around_position(
&self,
position: &TextDocumentPositionParams,
characters: usize,
) -> anyhow::Result<String> {
let rope = self
.file_map
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
Ok(rope
.get_line(position.position.line as usize)
.context("Error getting filter_text")?
.to_string())
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
let start = cursor_index.checked_sub(characters / 2).unwrap_or(0);
let end = rope
.len_chars()
.min(cursor_index + (characters - (cursor_index - start)));
let rope_slice = rope
.get_slice(start..end)
.context("Error getting rope slice")?;
Ok(rope_slice.to_string())
}
#[instrument(skip(self))]
fn build_prompt(
pub fn build_code(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
max_context_length: usize,
) -> anyhow::Result<String> {
let mut rope = self
.file_map
.get(position.text_document.uri.as_str())
@@ -66,13 +73,11 @@ impl MemoryBackend for FileStore {
// We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
// and the user has not enabled chat
let code = match (is_chat_enabled, self.configuration.get_fim()?) {
Ok(match (is_chat_enabled, self.configuration.get_fim()?) {
r @ (true, _) | r @ (false, Some(_))
if is_chat_enabled || rope.len_chars() != cursor_index =>
{
let max_length = characters_to_estimated_tokens(
self.configuration.get_maximum_context_length()?,
);
let max_length = tokens_to_estimated_characters(max_context_length);
let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
let end = rope
.len_chars()
@@ -103,16 +108,42 @@ impl MemoryBackend for FileStore {
}
_ => {
let start = cursor_index
.checked_sub(characters_to_estimated_tokens(
self.configuration.get_maximum_context_length()?,
))
.checked_sub(tokens_to_estimated_characters(max_context_length))
.unwrap_or(0);
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
rope_slice.to_string()
}
};
})
}
}
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
let rope = self
.file_map
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
Ok(rope
.get_line(position.position.line as usize)
.context("Error getting filter_text")?
.to_string())
}
#[instrument(skip(self))]
fn build_prompt(
&mut self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
let code = self.build_code(
position,
prompt_for_type,
self.configuration.get_max_context_length()?,
)?;
Ok(Prompt::new("".to_string(), code))
}