From d1a3e48be3ff91332ee1e570d52caa63a30c23b6 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Sat, 6 Apr 2024 15:12:44 -0700 Subject: [PATCH] Renamed configuration to config --- src/{configuration.rs => config.rs} | 14 +++++++------- src/main.rs | 6 +++--- src/memory_backends/file_store.rs | 13 +++++-------- src/memory_backends/mod.rs | 6 +++--- src/memory_backends/postgresml/mod.rs | 8 ++++---- src/template.rs | 2 +- src/transformer_backends/anthropic.rs | 10 +++++----- src/transformer_backends/llama_cpp/mod.rs | 8 ++++---- src/transformer_backends/llama_cpp/model.rs | 2 +- src/transformer_backends/mod.rs | 6 +++--- src/transformer_backends/openai/mod.rs | 14 +++++++------- src/transformer_worker.rs | 6 +++--- src/utils.rs | 2 +- 13 files changed, 47 insertions(+), 50 deletions(-) rename src/{configuration.rs => config.rs} (98%) diff --git a/src/configuration.rs b/src/config.rs similarity index 98% rename from src/configuration.rs rename to src/config.rs index 90c3b4a..cce73d4 100644 --- a/src/configuration.rs +++ b/src/config.rs @@ -228,7 +228,7 @@ pub struct Anthropic { } #[derive(Clone, Debug, Deserialize, Default)] -pub struct ValidConfiguration { +pub struct ValidConfig { #[serde(default)] pub memory: ValidMemoryBackend, #[serde(default)] @@ -243,12 +243,12 @@ pub struct ValidClientParams { } #[derive(Clone, Debug)] -pub struct Configuration { - pub config: ValidConfiguration, +pub struct Config { + pub config: ValidConfig, _client_params: ValidClientParams, } -impl Configuration { +impl Config { pub fn new(mut args: Value) -> Result { let configuration_args = args .as_object_mut() @@ -256,7 +256,7 @@ impl Configuration { .remove("initializationOptions"); let valid_args = match configuration_args { Some(configuration_args) => serde_json::from_value(configuration_args)?, - None => ValidConfiguration::default(), + None => ValidConfig::default(), }; let client_params: ValidClientParams = serde_json::from_value(args)?; Ok(Self { @@ -364,7 +364,7 @@ mod test { } } }); - Configuration::new(args).unwrap(); + Config::new(args).unwrap(); } #[test] @@ -391,6 +391,6 @@ mod test { } } }); - Configuration::new(args).unwrap(); + Config::new(args).unwrap(); } } diff --git a/src/main.rs b/src/main.rs index a6f9427..dbb57c1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ use std::{ use tracing::error; use tracing_subscriber::{EnvFilter, FmtSubscriber}; -mod configuration; +mod config; mod custom_requests; mod memory_backends; mod memory_worker; @@ -21,7 +21,7 @@ mod transformer_backends; mod transformer_worker; mod utils; -use configuration::Configuration; +use config::Config; use custom_requests::generate::Generate; use memory_backends::MemoryBackend; use transformer_backends::TransformerBackend; @@ -74,7 +74,7 @@ fn main() -> Result<()> { fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { // Build our configuration - let config = Configuration::new(args)?; + let config = Config::new(args)?; // Wrap the connection for sharing between threads let connection = Arc::new(connection); diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs index 7d10769..125cf5d 100644 --- a/src/memory_backends/file_store.rs +++ b/src/memory_backends/file_store.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; use tracing::instrument; use crate::{ - configuration::{self, Configuration}, + config::{self, Config}, utils::tokens_to_estimated_characters, }; @@ -15,13 +15,13 @@ use super::{MemoryBackend, Prompt, PromptForType}; pub struct FileStore { _crawl: bool, - configuration: Configuration, + configuration: Config, file_map: Mutex>, accessed_files: Mutex>, } impl FileStore { - pub fn new(file_store_config: configuration::FileStore, configuration: Configuration) -> Self { + pub fn new(file_store_config: config::FileStore, configuration: Config) -> Self { Self { _crawl: file_store_config.crawl, configuration, @@ -30,7 +30,7 @@ impl FileStore { } } - pub fn new_without_crawl(configuration: Configuration) -> Self { + pub fn new_without_crawl(configuration: Config) -> Self { Self { _crawl: false, configuration, @@ -122,11 +122,8 @@ impl FileStore { .unwrap_or(false), }; - // We only want to do FIM if the user has enabled it and the user has not enabled chat Ok(match (is_chat_enabled, self.configuration.get_fim()) { - r @ (true, _) | r @ (false, Some(_)) - if is_chat_enabled || rope.len_chars() != cursor_index => - { + r @ (true, _) | r @ (false, Some(_)) if rope.len_chars() != cursor_index => { let max_length = tokens_to_estimated_characters(max_context_length); let start = cursor_index.saturating_sub(max_length / 2); let end = rope diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs index a1f0aeb..d810434 100644 --- a/src/memory_backends/mod.rs +++ b/src/memory_backends/mod.rs @@ -3,7 +3,7 @@ use lsp_types::{ TextDocumentPositionParams, }; -use crate::configuration::{Configuration, ValidMemoryBackend}; +use crate::config::{Config, ValidMemoryBackend}; pub mod file_store; mod postgresml; @@ -48,10 +48,10 @@ pub trait MemoryBackend { ) -> anyhow::Result; } -impl TryFrom for Box { +impl TryFrom for Box { type Error = anyhow::Error; - fn try_from(configuration: Configuration) -> Result { + fn try_from(configuration: Config) -> Result { match configuration.config.memory.clone() { ValidMemoryBackend::FileStore(file_store_config) => Ok(Box::new( file_store::FileStore::new(file_store_config, configuration), diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs index d3e450a..ab05987 100644 --- a/src/memory_backends/postgresml/mod.rs +++ b/src/memory_backends/postgresml/mod.rs @@ -11,14 +11,14 @@ use tokio::time; use tracing::instrument; use crate::{ - configuration::{self, Configuration}, + config::{self, Config}, utils::tokens_to_estimated_characters, }; use super::{file_store::FileStore, MemoryBackend, Prompt, PromptForType}; pub struct PostgresML { - configuration: Configuration, + configuration: Config, file_store: FileStore, collection: Collection, pipeline: Pipeline, @@ -28,8 +28,8 @@ pub struct PostgresML { impl PostgresML { pub fn new( - postgresml_config: configuration::PostgresML, - configuration: Configuration, + postgresml_config: config::PostgresML, + configuration: Config, ) -> anyhow::Result { let file_store = FileStore::new_without_crawl(configuration.clone()); let database_url = if let Some(database_url) = postgresml_config.database_url { diff --git a/src/template.rs b/src/template.rs index 3d75141..c98634f 100644 --- a/src/template.rs +++ b/src/template.rs @@ -2,7 +2,7 @@ use minijinja::{context, Environment, ErrorKind}; use once_cell::sync::Lazy; use parking_lot::Mutex; -use crate::configuration::ChatMessage; +use crate::config::ChatMessage; static MINININJA_ENVIRONMENT: Lazy> = Lazy::new(|| Mutex::new(Environment::new())); diff --git a/src/transformer_backends/anthropic.rs b/src/transformer_backends/anthropic.rs index be47c79..a04adc0 100644 --- a/src/transformer_backends/anthropic.rs +++ b/src/transformer_backends/anthropic.rs @@ -4,7 +4,7 @@ use serde_json::{json, Value}; use tracing::instrument; use crate::{ - configuration::{self, ChatMessage}, + config::{self, ChatMessage}, memory_backends::Prompt, transformer_worker::{ DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest, @@ -15,7 +15,7 @@ use crate::{ use super::TransformerBackend; pub struct Anthropic { - configuration: configuration::Anthropic, + configuration: config::Anthropic, } #[derive(Deserialize)] @@ -31,7 +31,7 @@ struct AnthropicChatResponse { impl Anthropic { #[instrument] - pub fn new(configuration: configuration::Anthropic) -> Self { + pub fn new(configuration: config::Anthropic) -> Self { Self { configuration } } @@ -139,7 +139,7 @@ mod test { #[tokio::test] async fn anthropic_chat_do_completion() -> anyhow::Result<()> { - let configuration: configuration::Anthropic = serde_json::from_value(json!({ + let configuration: config::Anthropic = serde_json::from_value(json!({ "chat_endpoint": "https://api.anthropic.com/v1/messages", "model": "claude-3-haiku-20240307", "auth_token_env_var_name": "ANTHROPIC_API_KEY", @@ -170,7 +170,7 @@ mod test { #[tokio::test] async fn anthropic_chat_do_generate() -> anyhow::Result<()> { - let configuration: configuration::Anthropic = serde_json::from_value(json!({ + let configuration: config::Anthropic = serde_json::from_value(json!({ "chat_endpoint": "https://api.anthropic.com/v1/messages", "model": "claude-3-haiku-20240307", "auth_token_env_var_name": "ANTHROPIC_API_KEY", diff --git a/src/transformer_backends/llama_cpp/mod.rs b/src/transformer_backends/llama_cpp/mod.rs index ec98cb5..dd2ca1d 100644 --- a/src/transformer_backends/llama_cpp/mod.rs +++ b/src/transformer_backends/llama_cpp/mod.rs @@ -3,7 +3,7 @@ use hf_hub::api::sync::ApiBuilder; use tracing::instrument; use crate::{ - configuration::{self}, + config::{self}, memory_backends::Prompt, template::apply_chat_template, transformer_worker::{ @@ -19,12 +19,12 @@ use super::TransformerBackend; pub struct LLaMACPP { model: Model, - configuration: configuration::LLaMACPP, + configuration: config::LLaMACPP, } impl LLaMACPP { #[instrument] - pub fn new(configuration: configuration::LLaMACPP) -> anyhow::Result { + pub fn new(configuration: config::LLaMACPP) -> anyhow::Result { let api = ApiBuilder::new().with_progress(true).build()?; let name = configuration .model @@ -99,7 +99,7 @@ mod test { #[tokio::test] async fn llama_cpp_do_completion() -> anyhow::Result<()> { - let configuration: configuration::LLaMACPP = serde_json::from_value(json!({ + let configuration: config::LLaMACPP = serde_json::from_value(json!({ "repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF", "name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf", "max_new_tokens": { diff --git a/src/transformer_backends/llama_cpp/model.rs b/src/transformer_backends/llama_cpp/model.rs index b813d1f..c5bcdfd 100644 --- a/src/transformer_backends/llama_cpp/model.rs +++ b/src/transformer_backends/llama_cpp/model.rs @@ -11,7 +11,7 @@ use once_cell::sync::Lazy; use std::{num::NonZeroU32, path::PathBuf, time::Duration}; use tracing::{debug, info, instrument}; -use crate::configuration::{ChatMessage, Kwargs}; +use crate::config::{ChatMessage, Kwargs}; static BACKEND: Lazy = Lazy::new(|| LlamaBackend::init().unwrap()); diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs index 87148a1..30b464c 100644 --- a/src/transformer_backends/mod.rs +++ b/src/transformer_backends/mod.rs @@ -1,5 +1,5 @@ use crate::{ - configuration::{Configuration, ValidTransformerBackend}, + config::{Config, ValidTransformerBackend}, memory_backends::Prompt, transformer_worker::{ DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest, @@ -20,10 +20,10 @@ pub trait TransformerBackend { ) -> anyhow::Result; } -impl TryFrom for Box { +impl TryFrom for Box { type Error = anyhow::Error; - fn try_from(configuration: Configuration) -> Result { + fn try_from(configuration: Config) -> Result { match configuration.config.transformer { ValidTransformerBackend::LLaMACPP(model_gguf) => { Ok(Box::new(llama_cpp::LLaMACPP::new(model_gguf)?)) diff --git a/src/transformer_backends/openai/mod.rs b/src/transformer_backends/openai/mod.rs index 052776c..aedc472 100644 --- a/src/transformer_backends/openai/mod.rs +++ b/src/transformer_backends/openai/mod.rs @@ -7,7 +7,7 @@ use serde_json::{json, Value}; use tracing::instrument; use crate::{ - configuration::{self, ChatMessage}, + config::{self, ChatMessage}, memory_backends::Prompt, transformer_worker::{ DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest, @@ -18,7 +18,7 @@ use crate::{ use super::TransformerBackend; pub struct OpenAI { - configuration: configuration::OpenAI, + configuration: config::OpenAI, } #[derive(Deserialize)] @@ -45,7 +45,7 @@ struct OpenAIChatResponse { impl OpenAI { #[instrument] - pub fn new(configuration: configuration::OpenAI) -> Self { + pub fn new(configuration: config::OpenAI) -> Self { Self { configuration } } @@ -201,7 +201,7 @@ mod test { #[tokio::test] async fn openai_completion_do_completion() -> anyhow::Result<()> { - let configuration: configuration::OpenAI = serde_json::from_value(json!({ + let configuration: config::OpenAI = serde_json::from_value(json!({ "completions_endpoint": "https://api.openai.com/v1/completions", "model": "gpt-3.5-turbo-instruct", "auth_token_env_var_name": "OPENAI_API_KEY", @@ -220,7 +220,7 @@ mod test { #[tokio::test] async fn openai_chat_do_completion() -> anyhow::Result<()> { - let configuration: configuration::OpenAI = serde_json::from_value(json!({ + let configuration: config::OpenAI = serde_json::from_value(json!({ "chat_endpoint": "https://api.openai.com/v1/chat/completions", "model": "gpt-3.5-turbo", "auth_token_env_var_name": "OPENAI_API_KEY", @@ -251,7 +251,7 @@ mod test { #[tokio::test] async fn openai_completion_do_generate() -> anyhow::Result<()> { - let configuration: configuration::OpenAI = serde_json::from_value(json!({ + let configuration: config::OpenAI = serde_json::from_value(json!({ "completions_endpoint": "https://api.openai.com/v1/completions", "model": "gpt-3.5-turbo-instruct", "auth_token_env_var_name": "OPENAI_API_KEY", @@ -270,7 +270,7 @@ mod test { #[tokio::test] async fn openai_chat_do_generate() -> anyhow::Result<()> { - let configuration: configuration::OpenAI = serde_json::from_value(json!({ + let configuration: config::OpenAI = serde_json::from_value(json!({ "chat_endpoint": "https://api.openai.com/v1/chat/completions", "model": "gpt-3.5-turbo", "auth_token_env_var_name": "OPENAI_API_KEY", diff --git a/src/transformer_worker.rs b/src/transformer_worker.rs index 7b68965..fcc6914 100644 --- a/src/transformer_worker.rs +++ b/src/transformer_worker.rs @@ -8,7 +8,7 @@ use std::time::SystemTime; use tokio::sync::oneshot; use tracing::{debug, error, instrument}; -use crate::configuration::Configuration; +use crate::config::Config; use crate::custom_requests::generate::{GenerateParams, GenerateResult}; use crate::custom_requests::generate_stream::GenerateStreamParams; use crate::memory_backends::PromptForType; @@ -117,7 +117,7 @@ fn do_run( memory_backend_tx: std::sync::mpsc::Sender, transformer_rx: std::sync::mpsc::Receiver, connection: Arc, - config: Configuration, + config: Config, ) -> anyhow::Result<()> { let transformer_backend = Arc::new(transformer_backend); let runtime = tokio::runtime::Builder::new_multi_thread() @@ -166,7 +166,7 @@ pub fn run( memory_tx: std::sync::mpsc::Sender, transformer_rx: std::sync::mpsc::Receiver, connection: Arc, - config: Configuration, + config: Config, ) { if let Err(e) = do_run( transformer_backend, diff --git a/src/utils.rs b/src/utils.rs index ff24a26..82b2ddf 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,6 +1,6 @@ use lsp_server::ResponseError; -use crate::{configuration::ChatMessage, memory_backends::Prompt}; +use crate::{config::ChatMessage, memory_backends::Prompt}; pub trait ToResponseError { fn to_response_error(&self, code: i32) -> ResponseError;