diff --git a/crates/lsp-ai/src/config.rs b/crates/lsp-ai/src/config.rs
index 8474cea..f255464 100644
--- a/crates/lsp-ai/src/config.rs
+++ b/crates/lsp-ai/src/config.rs
@@ -69,10 +69,44 @@ pub struct TextSplitter {
     pub chunk_size: usize,
 }
 
+#[derive(Debug, Clone, Deserialize, Default)]
+pub struct EmbeddingPrefix {
+    #[serde(default)]
+    pub storage: String,
+    #[serde(default)]
+    pub retrieval: String,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct OllamaEmbeddingModel {
+    // The embeddings endpoint, default: 'http://localhost:11434/api/embeddings'
+    pub endpoint: Option<String>,
+    // The model name
+    pub model: String,
+    // The prefixes to prepend to text before embedding
+    #[serde(default)]
+    pub prefix: EmbeddingPrefix,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub enum ValidEmbeddingModel {
+    Ollama(OllamaEmbeddingModel),
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct VectorStore {
+    pub crawl: Option<Crawl>,
+    #[serde(default)]
+    pub splitter: ValidSplitter,
+    pub embedding_model: ValidEmbeddingModel,
+}
+
 #[derive(Debug, Clone, Deserialize)]
 pub(crate) enum ValidMemoryBackend {
     #[serde(rename = "file_store")]
     FileStore(FileStore),
+    #[serde(rename = "vector_store")]
+    VectorStore(VectorStore),
     #[serde(rename = "postgresml")]
     PostgresML(PostgresML),
 }
diff --git a/crates/lsp-ai/src/embedding_models/mod.rs b/crates/lsp-ai/src/embedding_models/mod.rs
new file mode 100644
index 0000000..a90576b
--- /dev/null
+++ b/crates/lsp-ai/src/embedding_models/mod.rs
@@ -0,0 +1,28 @@
+use crate::config::ValidEmbeddingModel;
+
+mod ollama;
+
+#[derive(Clone, Copy)]
+pub enum EmbeddingPurpose {
+    Storage,
+    Retrieval,
+}
+
+#[async_trait::async_trait]
+pub trait EmbeddingModel {
+    async fn embed(
+        &self,
+        batch: &[&str],
+        purpose: EmbeddingPurpose,
+    ) -> anyhow::Result<Vec<Vec<f32>>>;
+}
+
+impl TryFrom<ValidEmbeddingModel> for Box<dyn EmbeddingModel + Send + Sync> {
+    type Error = anyhow::Error;
+
+    fn try_from(value: ValidEmbeddingModel) -> Result<Self, Self::Error> {
+        match value {
+            ValidEmbeddingModel::Ollama(config) => Ok(Box::new(ollama::Ollama::new(config))),
+        }
+    }
+}
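
The `TryFrom` conversion above is what lets a memory backend stay generic over embedding providers. A minimal sketch of a call site, assuming the trait and conversion exactly as added in this diff (the helper function itself is hypothetical and not part of the change):

    use crate::config::ValidEmbeddingModel;
    use crate::embedding_models::{EmbeddingModel, EmbeddingPurpose};

    // Hypothetical call site: turn the deserialized config into a boxed
    // model, then embed a batch of chunk texts for storage.
    async fn embed_for_storage(
        config: ValidEmbeddingModel,
        chunks: &[&str],
    ) -> anyhow::Result<Vec<Vec<f32>>> {
        let model: Box<dyn EmbeddingModel + Send + Sync> = config.try_into()?;
        model.embed(chunks, EmbeddingPurpose::Storage).await
    }
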
diff --git a/crates/lsp-ai/src/embedding_models/ollama.rs b/crates/lsp-ai/src/embedding_models/ollama.rs
new file mode 100644
index 0000000..ef337a6
--- /dev/null
+++ b/crates/lsp-ai/src/embedding_models/ollama.rs
@@ -0,0 +1,102 @@
+use std::collections::HashMap;
+
+use serde::Deserialize;
+use serde_json::{json, Value};
+
+use crate::config;
+
+use super::{EmbeddingModel, EmbeddingPurpose};
+
+#[derive(Deserialize)]
+pub struct EmbedResponse {
+    embedding: Option<Vec<f32>>,
+    error: Option<Value>,
+    #[serde(default)]
+    #[serde(flatten)]
+    other: HashMap<String, Value>,
+}
+
+pub struct Ollama {
+    config: config::OllamaEmbeddingModel,
+}
+
+impl Ollama {
+    pub fn new(config: config::OllamaEmbeddingModel) -> Self {
+        Self { config }
+    }
+}
+
+#[async_trait::async_trait]
+impl EmbeddingModel for Ollama {
+    async fn embed(
+        &self,
+        batch: &[&str],
+        purpose: EmbeddingPurpose,
+    ) -> anyhow::Result<Vec<Vec<f32>>> {
+        let mut results = vec![];
+        let prefix = match purpose {
+            EmbeddingPurpose::Storage => &self.config.prefix.storage,
+            EmbeddingPurpose::Retrieval => &self.config.prefix.retrieval,
+        };
+        let client = reqwest::Client::new();
+        for item in batch {
+            let prompt = format!("{prefix}{item}");
+            let res: EmbedResponse = client
+                .post(
+                    self.config
+                        .endpoint
+                        .as_deref()
+                        .unwrap_or("http://localhost:11434/api/embeddings"),
+                )
+                .header("Content-Type", "application/json")
+                .header("Accept", "application/json")
+                .json(&json!({
+                    "model": self.config.model,
+                    "prompt": prompt
+                }))
+                .send()
+                .await?
+                .json()
+                .await?;
+            if let Some(error) = res.error {
+                anyhow::bail!("{}", error)
+            } else if let Some(embedding) = res.embedding {
+                results.push(embedding);
+            } else {
+                anyhow::bail!(
+                    "Unknown error while making request to Ollama: {:?}",
+                    res.other
+                )
+            }
+        }
+        Ok(results)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[tokio::test]
+    async fn ollama_embedding() -> anyhow::Result<()> {
+        let configuration: config::OllamaEmbeddingModel = serde_json::from_value(json!({
+            "model": "nomic-embed-text",
+            "prefix": {
+                "retrieval": "search_query",
+                "storage": "search_document"
+            }
+        }))?;
+
+        let ollama = Ollama::new(configuration);
+        let results = ollama
+            .embed(
+                &["Hello world!", "How are you?"],
+                EmbeddingPurpose::Retrieval,
+            )
+            .await?;
+        assert_eq!(results.len(), 2);
+        assert_eq!(results[0].len(), 768);
+
+        Ok(())
+    }
+}
diff --git a/crates/lsp-ai/src/main.rs b/crates/lsp-ai/src/main.rs
index 106be0a..09b14cd 100644
--- a/crates/lsp-ai/src/main.rs
+++ b/crates/lsp-ai/src/main.rs
@@ -16,6 +16,7 @@ use tracing_subscriber::{EnvFilter, FmtSubscriber};
 mod config;
 mod crawl;
 mod custom_requests;
+mod embedding_models;
 mod memory_backends;
 mod memory_worker;
 mod splitters;
diff --git a/crates/lsp-ai/src/memory_backends/mod.rs b/crates/lsp-ai/src/memory_backends/mod.rs
index 8fa5914..2bb13d1 100644
--- a/crates/lsp-ai/src/memory_backends/mod.rs
+++ b/crates/lsp-ai/src/memory_backends/mod.rs
@@ -8,6 +8,7 @@ use crate::config::{Config, ValidMemoryBackend};
 
 pub(crate) mod file_store;
 mod postgresml;
+mod vector_store;
 
 #[derive(Clone, Debug)]
 pub enum PromptType {
@@ -136,6 +137,9 @@ impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {
             ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new(
                 postgresml::PostgresML::new(postgresml_config, configuration)?,
             )),
+            ValidMemoryBackend::VectorStore(vector_store_config) => Ok(Box::new(
+                vector_store::VectorStore::new(vector_store_config, configuration)?,
+            )),
         }
     }
 }
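
As a sanity check on the wiring above, the new `vector_store` variant should deserialize end to end. A sketch of such a test, assuming the config structs exactly as added in this diff (note that `ValidEmbeddingModel` is externally tagged, hence the `Ollama` key; `crawl` and `splitter` are optional or defaulted; the prefix strings are illustrative):

    #[cfg(test)]
    mod config_test {
        use serde_json::json;

        use crate::config;

        #[test]
        fn deserializes_vector_store_config() -> anyhow::Result<()> {
            // Only the embedding model is required; `splitter` falls back
            // to its serde default and `crawl` is an Option.
            let _config: config::VectorStore = serde_json::from_value(json!({
                "embedding_model": {
                    "Ollama": {
                        "model": "nomic-embed-text",
                        "prefix": {
                            "storage": "search_document: ",
                            "retrieval": "search_query: "
                        }
                    }
                }
            }))?;
            Ok(())
        }
    }
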
diff --git a/crates/lsp-ai/src/memory_backends/vector_store.rs b/crates/lsp-ai/src/memory_backends/vector_store.rs
new file mode 100644
index 0000000..07fac7c
--- /dev/null
+++ b/crates/lsp-ai/src/memory_backends/vector_store.rs
@@ -0,0 +1,92 @@
+use std::sync::Arc;
+
+use anyhow::Context;
+use lsp_types::{
+    DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
+    TextDocumentPositionParams,
+};
+use parking_lot::Mutex;
+use serde_json::Value;
+
+use crate::{
+    config::{self, Config},
+    crawl::Crawl,
+    splitters::Splitter,
+};
+
+use super::{
+    file_store::{AdditionalFileStoreParams, FileStore},
+    MemoryBackend, Prompt, PromptType,
+};
+
+pub struct VectorStore {
+    file_store: FileStore,
+    // TODO: Verify we need these Arc<>
+    crawl: Option<Arc<Mutex<Crawl>>>,
+    splitter: Arc<Box<dyn Splitter>>,
+}
+
+impl VectorStore {
+    pub fn new(
+        mut vector_store_config: config::VectorStore,
+        config: Config,
+    ) -> anyhow::Result<Self> {
+        let crawl = vector_store_config
+            .crawl
+            .take()
+            .map(|x| Arc::new(Mutex::new(Crawl::new(x, config.clone()))));
+
+        let splitter: Arc<Box<dyn Splitter>> =
+            Arc::new(vector_store_config.splitter.clone().try_into()?);
+
+        let file_store = FileStore::new_with_params(
+            config::FileStore::new_without_crawl(),
+            config.clone(),
+            AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
+        )?;
+
+        Ok(Self {
+            file_store,
+            crawl,
+            splitter,
+        })
+    }
+}
+
+#[async_trait::async_trait]
+impl MemoryBackend for VectorStore {
+    fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()> {
+        // Pass through to the file store first so the file is tracked
+        let uri = params.text_document.uri.to_string();
+        self.file_store.opened_text_document(params)?;
+        // Split into chunks
+        let file_map = self.file_store.file_map().lock();
+        let file = file_map.get(&uri).context("file not found")?;
+        let chunks = self.splitter.split(file);
+        // TODO: embed the chunks and store them
+        Ok(())
+    }
+
+    fn changed_text_document(&self, params: DidChangeTextDocumentParams) -> anyhow::Result<()> {
+        self.file_store.changed_text_document(params.clone())?;
+        Ok(())
+    }
+
+    fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()> {
+        self.file_store.renamed_files(params.clone())?;
+        Ok(())
+    }
+
+    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
+        self.file_store.get_filter_text(position)
+    }
+
+    async fn build_prompt(
+        &self,
+        position: &TextDocumentPositionParams,
+        prompt_type: PromptType,
+        params: &Value,
+    ) -> anyhow::Result<Prompt> {
+        todo!()
+    }
+}
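
`opened_text_document` above stops short of the actual embedding step. One possible shape for filling that gap, entirely hypothetical and not part of this diff: it assumes `VectorStore` gains an `embedding_model: Arc<dyn EmbeddingModel + Send + Sync>` field (built via the `TryFrom` in embedding_models/mod.rs), that each chunk exposes its text as `chunk.text`, and that a tokio runtime is available to bridge the sync LSP callback to the async embed call:

    // Inside opened_text_document, replacing the `// TODO: embed the chunks` stub:
    let texts: Vec<String> = chunks.into_iter().map(|chunk| chunk.text).collect();
    let model = self.embedding_model.clone();
    tokio::spawn(async move {
        let refs: Vec<&str> = texts.iter().map(String::as_str).collect();
        match model.embed(&refs, EmbeddingPurpose::Storage).await {
            // Storing each (chunk, embedding) pair is what makes later
            // similarity search possible.
            Ok(_embeddings) => { /* persist chunks alongside embeddings */ }
            Err(e) => tracing::error!("failed to embed chunks: {e}"),
        }
    });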