Added vector store backend and ollama embedding option

This commit is contained in:
SilasMarvin
2024-06-24 22:31:02 -07:00
parent 1f70756c5b
commit 0ac708cce0
6 changed files with 261 additions and 0 deletions

View File

@@ -69,10 +69,44 @@ pub struct TextSplitter {
pub chunk_size: usize,
}
/// Text prefixes that are prepended to an input before it is embedded.
///
/// Both fields are optional in the configuration and fall back to the empty
/// string when omitted.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct EmbeddingPrefix {
    /// Prefix used when the text is embedded for storage.
    pub storage: String,
    /// Prefix used when the text is embedded for retrieval.
    pub retrieval: String,
}
/// Configuration for an embedding model served by Ollama.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaEmbeddingModel {
/// The embeddings endpoint; when unset, 'http://localhost:11434/api/embeddings' is used
pub endpoint: Option<String>,
/// The model name, sent as the `model` field of each request
pub model: String,
/// The prefixes prepended to the text before it is embedded; defaults to empty prefixes
#[serde(default)]
pub prefix: EmbeddingPrefix,
}
/// The embedding model backends that can be selected in the configuration.
#[derive(Debug, Clone, Deserialize)]
pub enum ValidEmbeddingModel {
/// An embedding model reached through an Ollama endpoint.
Ollama(OllamaEmbeddingModel),
}
/// Configuration for the vector store memory backend.
#[derive(Debug, Clone, Deserialize)]
pub struct VectorStore {
/// Optional crawl configuration; no crawler is created when absent.
pub crawl: Option<Crawl>,
/// How files are split into chunks; falls back to `ValidSplitter::default()` when omitted.
#[serde(default)]
pub splitter: ValidSplitter,
/// The embedding model to use; this field is required.
pub embedding_model: ValidEmbeddingModel,
}
/// The memory backends that can be selected in the configuration.
///
/// Each variant is renamed for (de)serialization to the snake_case / lowercase
/// key given in its `serde(rename)` attribute.
#[derive(Debug, Clone, Deserialize)]
pub(crate) enum ValidMemoryBackend {
/// Configured under the key `file_store`.
#[serde(rename = "file_store")]
FileStore(FileStore),
/// Configured under the key `vector_store`.
#[serde(rename = "vector_store")]
VectorStore(VectorStore),
/// Configured under the key `postgresml`.
#[serde(rename = "postgresml")]
PostgresML(PostgresML),
}

View File

@@ -0,0 +1,28 @@
use crate::config::ValidEmbeddingModel;
mod ollama;
/// Why a piece of text is being embedded.
///
/// Embedding models use this to choose which configured prefix to apply.
#[derive(Clone, Copy)]
pub enum EmbeddingPurpose {
/// The text is being embedded in order to be stored.
Storage,
/// The text is being embedded in order to query stored embeddings.
Retrieval,
}
/// A model that turns text into embedding vectors.
#[async_trait::async_trait]
pub trait EmbeddingModel {
/// Embeds every string in `batch` for the given `purpose`.
///
/// Returns the embedding vectors on success. Implementations are expected to
/// return one vector per input item, in input order (the Ollama
/// implementation does); any failure is reported through the `anyhow` error.
async fn embed(
&self,
batch: &[&str],
purpose: EmbeddingPurpose,
) -> anyhow::Result<Vec<Vec<f32>>>;
}
/// Builds the concrete embedding model selected by the configuration.
impl TryFrom<ValidEmbeddingModel> for Box<dyn EmbeddingModel> {
    type Error = anyhow::Error;

    /// Converts the configuration into a boxed model. The conversion is
    /// currently infallible, but stays fallible so future backends can reject
    /// a configuration.
    fn try_from(value: ValidEmbeddingModel) -> Result<Self, Self::Error> {
        let model: Box<dyn EmbeddingModel> = match value {
            ValidEmbeddingModel::Ollama(ollama_config) => {
                Box::new(ollama::Ollama::new(ollama_config))
            }
        };
        Ok(model)
    }
}

View File

@@ -0,0 +1,102 @@
use std::collections::HashMap;
use serde::Deserialize;
use serde_json::{json, Value};
use crate::config;
use super::{EmbeddingModel, EmbeddingPurpose};
/// JSON body returned by the Ollama embeddings endpoint.
///
/// Either `embedding` or `error` is expected to be present; every other key
/// of the response is collected into `other` so it can be reported when
/// neither is.
#[derive(Deserialize)]
pub struct EmbedResponse {
    /// The embedding vector, when the request succeeded.
    embedding: Option<Vec<f32>>,
    /// The error payload, when the server reported a failure.
    error: Option<Value>,
    /// All remaining top-level keys of the response.
    #[serde(default, flatten)]
    other: HashMap<String, Value>,
}
/// Embedding model backed by an Ollama embeddings endpoint.
pub struct Ollama {
// Endpoint, model name and prefixes used for every request
config: config::OllamaEmbeddingModel,
}
impl Ollama {
pub fn new(config: config::OllamaEmbeddingModel) -> Self {
Self { config }
}
}
#[async_trait::async_trait]
impl EmbeddingModel for Ollama {
async fn embed(
&self,
batch: &[&str],
purpose: EmbeddingPurpose,
) -> anyhow::Result<Vec<Vec<f32>>> {
let mut results = vec![];
let prefix = match purpose {
EmbeddingPurpose::Storage => &self.config.prefix.storage,
EmbeddingPurpose::Retrieval => &self.config.prefix.retrieval,
};
let client = reqwest::Client::new();
for item in batch {
let prompt = format!("{prefix}{item}");
let res: EmbedResponse = client
.post(
self.config
.endpoint
.as_deref()
.unwrap_or("http://localhost:11434/api/embeddings"),
)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&json!({
"model": self.config.model,
"prompt": prompt
}))
.send()
.await?
.json()
.await?;
if let Some(error) = res.error {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(embedding) = res.embedding {
results.push(embedding);
} else {
anyhow::bail!(
"Unknown error while making request to Ollama: {:?}",
res.other
)
}
}
Ok(results)
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Embeds two strings with `nomic-embed-text` and checks the result shape.
    ///
    /// No endpoint is configured, so the default local endpoint is used: this
    /// test needs an Ollama server reachable there with the model available.
    #[tokio::test]
    async fn ollama_embedding() -> anyhow::Result<()> {
        // The prefix is concatenated directly with the text, so the separator
        // has to be part of the configured prefix.
        let configuration: config::OllamaEmbeddingModel = serde_json::from_value(json!({
            "model": "nomic-embed-text",
            "prefix": {
                "retrieval": "search_query: ",
                "storage": "search_document: "
            }
        }))?;
        let ollama = Ollama::new(configuration);
        let results = ollama
            .embed(
                &["Hello world!", "How are you?"],
                EmbeddingPurpose::Retrieval,
            )
            .await?;
        // One embedding per input, each with the 768 dimensions this test expects.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 768);
        Ok(())
    }
}

View File

@@ -16,6 +16,7 @@ use tracing_subscriber::{EnvFilter, FmtSubscriber};
mod config;
mod crawl;
mod custom_requests;
mod embedding_models;
mod memory_backends;
mod memory_worker;
mod splitters;

View File

@@ -8,6 +8,7 @@ use crate::config::{Config, ValidMemoryBackend};
pub(crate) mod file_store;
mod postgresml;
mod vector_store;
#[derive(Clone, Debug)]
pub enum PromptType {
@@ -136,6 +137,9 @@ impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {
ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new(
postgresml::PostgresML::new(postgresml_config, configuration)?,
)),
ValidMemoryBackend::VectorStore(vector_store_config) => Ok(Box::new(
vector_store::VectorStore::new(vector_store_config, configuration)?,
)),
}
}
}

View File

@@ -0,0 +1,92 @@
use std::sync::Arc;
use anyhow::Context;
use lsp_types::{
DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
TextDocumentPositionParams,
};
use parking_lot::Mutex;
use serde_json::Value;
use crate::{
config::{self, Config},
crawl::Crawl,
splitters::Splitter,
};
use super::{
file_store::{AdditionalFileStoreParams, FileStore},
MemoryBackend, Prompt, PromptType,
};
/// Memory backend that wraps a `FileStore` and splits opened files into chunks.
pub struct VectorStore {
// Underlying file store that all document events are forwarded to
file_store: FileStore,
// TODO: Verify we need these Arc<>
// Optional crawler, present only when a crawl configuration was supplied
crawl: Option<Arc<Mutex<Crawl>>>,
// Splitter used to break files into chunks
splitter: Arc<Box<dyn Splitter + Send + Sync>>,
}
impl VectorStore {
    /// Creates the vector store backend from its configuration.
    ///
    /// A crawler is created only when `vector_store_config.crawl` is set. The
    /// inner `FileStore` is created without a crawl configuration of its own
    /// and is told whether the configured splitter uses tree-sitter.
    ///
    /// # Errors
    ///
    /// Returns an error if the splitter cannot be built from its configuration
    /// or if the inner `FileStore` cannot be created.
    pub fn new(vector_store_config: config::VectorStore, config: Config) -> anyhow::Result<Self> {
        // The configuration is owned, so its fields are moved out rather than cloned.
        let crawl = vector_store_config
            .crawl
            .map(|crawl_config| Arc::new(Mutex::new(Crawl::new(crawl_config, config.clone()))));
        let splitter: Arc<Box<dyn Splitter + Send + Sync>> =
            Arc::new(vector_store_config.splitter.try_into()?);
        // Last use of `config`: hand it over instead of cloning it again.
        let file_store = FileStore::new_with_params(
            config::FileStore::new_without_crawl(),
            config,
            AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
        )?;
        Ok(Self {
            file_store,
            crawl,
            splitter,
        })
    }
}
#[async_trait::async_trait]
impl MemoryBackend for VectorStore {
    /// Forwards the event to the file store, then splits the opened file into chunks.
    ///
    /// The chunks are not embedded or stored yet.
    ///
    /// # Errors
    ///
    /// Returns an error if the file store rejects the event or if the file is
    /// missing from the file map afterwards.
    fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()> {
        // Capture the URI before `params` is handed to the file store.
        let uri = params.text_document.uri.to_string();
        // Pass through
        self.file_store.opened_text_document(params)?;
        // Split into chunks
        let file_map = self.file_store.file_map().lock();
        let file = file_map.get(&uri).context("file not found")?;
        // TODO: embed the chunks; until then the binding is intentionally unused.
        let _chunks = self.splitter.split(file);
        Ok(())
    }

    /// Forwards the change event to the file store.
    fn changed_text_document(&self, params: DidChangeTextDocumentParams) -> anyhow::Result<()> {
        // `params` is owned and not used afterwards, so it is moved, not cloned.
        self.file_store.changed_text_document(params)?;
        Ok(())
    }

    /// Forwards the rename event to the file store.
    fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()> {
        // `params` is owned and not used afterwards, so it is moved, not cloned.
        self.file_store.renamed_files(params)?;
        Ok(())
    }

    /// Returns the filter text for `position`, as computed by the file store.
    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
        self.file_store.get_filter_text(position)
    }

    /// Builds the prompt for a completion at `position`.
    ///
    /// # Panics
    ///
    /// Not implemented yet: calling this always panics via `todo!()`.
    async fn build_prompt(
        &self,
        position: &TextDocumentPositionParams,
        prompt_type: PromptType,
        params: &Value,
    ) -> anyhow::Result<Prompt> {
        todo!()
    }
}