Mirror of https://github.com/SilasMarvin/lsp-ai.git
Added vector store backend and ollama embedding option
@@ -69,10 +69,44 @@ pub struct TextSplitter {
    pub chunk_size: usize,
}

#[derive(Debug, Clone, Deserialize, Default)]
pub struct EmbeddingPrefix {
    #[serde(default)]
    pub storage: String,
    #[serde(default)]
    pub retrieval: String,
}

#[derive(Debug, Clone, Deserialize)]
pub struct OllamaEmbeddingModel {
    // The generate endpoint, default: 'http://localhost:11434/api/embeddings'
    pub endpoint: Option<String>,
    // The model name
    pub model: String,
    // The prefix to apply to the embeddings
    #[serde(default)]
    pub prefix: EmbeddingPrefix,
}

#[derive(Debug, Clone, Deserialize)]
pub enum ValidEmbeddingModel {
    Ollama(OllamaEmbeddingModel),
}

#[derive(Debug, Clone, Deserialize)]
pub struct VectorStore {
    pub crawl: Option<Crawl>,
    #[serde(default)]
    pub splitter: ValidSplitter,
    pub embedding_model: ValidEmbeddingModel,
}

#[derive(Debug, Clone, Deserialize)]
pub(crate) enum ValidMemoryBackend {
    #[serde(rename = "file_store")]
    FileStore(FileStore),
    #[serde(rename = "vector_store")]
    VectorStore(VectorStore),
    #[serde(rename = "postgresml")]
    PostgresML(PostgresML),
}
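For reference, a `vector_store` memory backend configured through these structs would look roughly as follows. This is an editorial sketch rather than part of the commit: it assumes the structs above live in the crate's config module with serde_json available, and that both enums keep serde's default externally tagged representation (plus the explicit renames on ValidMemoryBackend). The prefix values are taken from the test added later in this commit.

// Editorial sketch: deserializing a vector_store backend configuration.
fn example_vector_store_backend() -> anyhow::Result<ValidMemoryBackend> {
    let value = serde_json::json!({
        "vector_store": {
            // `crawl` is Option<Crawl> and can simply be omitted;
            // `splitter` falls back to its #[serde(default)] value.
            "embedding_model": {
                // ValidEmbeddingModel carries no tag/rename attributes, so the
                // variant name itself is the key.
                "Ollama": {
                    "model": "nomic-embed-text",
                    "prefix": {
                        "storage": "search_document",
                        "retrieval": "search_query"
                    }
                }
            }
        }
    });
    Ok(serde_json::from_value(value)?)
}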
crates/lsp-ai/src/embedding_models/mod.rs (new file, 28 lines)
@@ -0,0 +1,28 @@
use crate::config::ValidEmbeddingModel;

mod ollama;

#[derive(Clone, Copy)]
pub enum EmbeddingPurpose {
    Storage,
    Retrieval,
}

#[async_trait::async_trait]
pub trait EmbeddingModel {
    async fn embed(
        &self,
        batch: &[&str],
        purpose: EmbeddingPurpose,
    ) -> anyhow::Result<Vec<Vec<f32>>>;
}

impl TryFrom<ValidEmbeddingModel> for Box<dyn EmbeddingModel> {
    type Error = anyhow::Error;

    fn try_from(value: ValidEmbeddingModel) -> Result<Self, Self::Error> {
        match value {
            ValidEmbeddingModel::Ollama(config) => Ok(Box::new(ollama::Ollama::new(config))),
        }
    }
}
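Together, the trait and the TryFrom conversion give callers one uniform path from configuration to a usable model. A minimal usage sketch (editorial, not part of the commit; written as if it sat alongside this module, with the items above in scope):

// Editorial sketch: turn config into a boxed model and embed a batch of chunks.
async fn embed_for_storage(
    model_config: ValidEmbeddingModel,
    chunks: &[&str],
) -> anyhow::Result<Vec<Vec<f32>>> {
    // The TryFrom impl above selects the concrete backend (currently only Ollama).
    let model: Box<dyn EmbeddingModel> = model_config.try_into()?;
    // EmbeddingPurpose decides which configured prefix is prepended to each item.
    model.embed(chunks, EmbeddingPurpose::Storage).await
}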
crates/lsp-ai/src/embedding_models/ollama.rs (new file, 102 lines)
@@ -0,0 +1,102 @@
use std::collections::HashMap;

use serde::Deserialize;
use serde_json::{json, Value};

use crate::config;

use super::{EmbeddingModel, EmbeddingPurpose};

#[derive(Deserialize)]
pub struct EmbedResponse {
    embedding: Option<Vec<f32>>,
    error: Option<Value>,
    #[serde(default)]
    #[serde(flatten)]
    other: HashMap<String, Value>,
}

pub struct Ollama {
    config: config::OllamaEmbeddingModel,
}

impl Ollama {
    pub fn new(config: config::OllamaEmbeddingModel) -> Self {
        Self { config }
    }
}

#[async_trait::async_trait]
impl EmbeddingModel for Ollama {
    async fn embed(
        &self,
        batch: &[&str],
        purpose: EmbeddingPurpose,
    ) -> anyhow::Result<Vec<Vec<f32>>> {
        let mut results = vec![];
        let prefix = match purpose {
            EmbeddingPurpose::Storage => &self.config.prefix.storage,
            EmbeddingPurpose::Retrieval => &self.config.prefix.retrieval,
        };
        let client = reqwest::Client::new();
        for item in batch {
            let prompt = format!("{prefix}{item}");
            let res: EmbedResponse = client
                .post(
                    self.config
                        .endpoint
                        .as_deref()
                        .unwrap_or("http://localhost:11434/api/embeddings"),
                )
                .header("Content-Type", "application/json")
                .header("Accept", "application/json")
                .json(&json!({
                    "model": self.config.model,
                    "prompt": prompt
                }))
                .send()
                .await?
                .json()
                .await?;
            if let Some(error) = res.error {
                anyhow::bail!("{:?}", error.to_string())
            } else if let Some(embedding) = res.embedding {
                results.push(embedding);
            } else {
                anyhow::bail!(
                    "Unknown error while making request to Ollama: {:?}",
                    res.other
                )
            }
        }
        Ok(results)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn ollama_embedding() -> anyhow::Result<()> {
        let configuration: config::OllamaEmbeddingModel = serde_json::from_value(json!({
            "model": "nomic-embed-text",
            "prefix": {
                "retrieval": "search_query",
                "storage": "search_document"
            }
        }))?;

        let ollama = Ollama::new(configuration);
        let results = ollama
            .embed(
                &["Hello world!", "How are you?"],
                EmbeddingPurpose::Retrieval,
            )
            .await?;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 768);

        Ok(())
    }
}
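The EmbedResponse struct mirrors the two payload shapes the Ollama /api/embeddings endpoint returns: an `embedding` array on success, an `error` field on failure, with anything unexpected collected into `other`. An editorial sketch of that parsing behavior (not part of the commit, written as if added inside the test module above so the private fields are visible):

    #[test]
    fn embed_response_parses_success_and_error() -> anyhow::Result<()> {
        // Success shape: {"embedding": [...]}
        let ok: EmbedResponse = serde_json::from_value(json!({ "embedding": [0.1, 0.2, 0.3] }))?;
        assert!(ok.embedding.is_some() && ok.error.is_none());

        // Failure shape: {"error": "..."}, e.g. for an unknown model name
        let err: EmbedResponse = serde_json::from_value(json!({ "error": "model not found" }))?;
        assert!(err.error.is_some() && err.embedding.is_none());
        Ok(())
    }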
@@ -16,6 +16,7 @@ use tracing_subscriber::{EnvFilter, FmtSubscriber};
mod config;
mod crawl;
mod custom_requests;
mod embedding_models;
mod memory_backends;
mod memory_worker;
mod splitters;
@@ -8,6 +8,7 @@ use crate::config::{Config, ValidMemoryBackend};

pub(crate) mod file_store;
mod postgresml;
mod vector_store;

#[derive(Clone, Debug)]
pub enum PromptType {
@@ -136,6 +137,9 @@ impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {
            ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new(
                postgresml::PostgresML::new(postgresml_config, configuration)?,
            )),
            ValidMemoryBackend::VectorStore(vector_store_config) => Ok(Box::new(
                vector_store::VectorStore::new(vector_store_config, configuration)?,
            )),
        }
    }
}
crates/lsp-ai/src/memory_backends/vector_store.rs (new file, 92 lines)
@@ -0,0 +1,92 @@
use std::sync::Arc;

use anyhow::Context;
use lsp_types::{
    DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
    TextDocumentPositionParams,
};
use parking_lot::Mutex;
use serde_json::Value;

use crate::{
    config::{self, Config},
    crawl::Crawl,
    splitters::Splitter,
};

use super::{
    file_store::{AdditionalFileStoreParams, FileStore},
    MemoryBackend, Prompt, PromptType,
};

pub struct VectorStore {
    file_store: FileStore,
    // TODO: Verify we need these Arc<>
    crawl: Option<Arc<Mutex<Crawl>>>,
    splitter: Arc<Box<dyn Splitter + Send + Sync>>,
}

impl VectorStore {
    pub fn new(
        mut vector_store_config: config::VectorStore,
        config: Config,
    ) -> anyhow::Result<Self> {
        let crawl = vector_store_config
            .crawl
            .take()
            .map(|x| Arc::new(Mutex::new(Crawl::new(x, config.clone()))));

        let splitter: Arc<Box<dyn Splitter + Send + Sync>> =
            Arc::new(vector_store_config.splitter.clone().try_into()?);

        let file_store = FileStore::new_with_params(
            config::FileStore::new_without_crawl(),
            config.clone(),
            AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
        )?;

        Ok(Self {
            file_store,
            crawl,
            splitter,
        })
    }
}

#[async_trait::async_trait]
impl MemoryBackend for VectorStore {
    fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()> {
        // Pass through
        let uri = params.text_document.uri.to_string();
        self.file_store.opened_text_document(params)?;
        // Split into chunks
        let file_map = self.file_store.file_map().lock();
        let file = file_map.get(&uri).context("file not found")?;
        let chunks = self.splitter.split(file);
        // Embed it
        Ok(())
    }

    fn changed_text_document(&self, params: DidChangeTextDocumentParams) -> anyhow::Result<()> {
        self.file_store.changed_text_document(params.clone())?;
        Ok(())
    }

    fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()> {
        self.file_store.renamed_files(params.clone())?;
        Ok(())
    }

    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
        self.file_store.get_filter_text(position)
    }

    async fn build_prompt(
        &self,
        position: &TextDocumentPositionParams,
        prompt_type: PromptType,
        params: &Value,
    ) -> anyhow::Result<Prompt> {
        todo!()
    }
}
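The opened_text_document handler currently stops after splitting; the "// Embed it" step is left for a follow-up. One possible shape for that step, sketched editorially against the EmbeddingModel trait introduced in this commit (the free function, its arguments, and any eventual storage of the resulting vectors are assumptions, not code from this change):

// Editorial sketch, not part of this commit.
use crate::embedding_models::{EmbeddingModel, EmbeddingPurpose};

async fn embed_chunks(
    model: &dyn EmbeddingModel,
    chunk_texts: &[&str],
) -> anyhow::Result<Vec<Vec<f32>>> {
    // Chunks are being written into the store, so the storage prefix applies;
    // query-time embedding would use EmbeddingPurpose::Retrieval instead.
    model.embed(chunk_texts, EmbeddingPurpose::Storage).await
}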