use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

pub type Kwargs = HashMap<String, Value>;

#[derive(Debug, Clone, Deserialize)]
pub enum ValidMemoryBackend {
    #[serde(rename = "file_store")]
    FileStore(FileStore),
    #[serde(rename = "postgresml")]
    PostgresML(PostgresML),
}

#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ValidModel {
    #[cfg(feature = "llamacpp")]
    #[serde(rename = "llamacpp")]
    LLaMACPP(LLaMACPP),
    #[serde(rename = "openai")]
    OpenAI(OpenAI),
    #[serde(rename = "anthropic")]
    Anthropic(Anthropic),
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Chat {
    pub completion: Option<Vec<ChatMessage>>,
    pub generation: Option<Vec<ChatMessage>>,
    pub chat_template: Option<String>,
    pub chat_format: Option<String>,
}

#[derive(Clone, Debug, Deserialize)]
#[allow(clippy::upper_case_acronyms)]
#[serde(deny_unknown_fields)]
pub struct FIM {
    pub start: String,
    pub middle: String,
    pub end: String,
}

#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PostgresML {
    pub database_url: Option<String>,
    #[serde(default)]
    pub crawl: bool,
}

#[derive(Clone, Debug, Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct FileStore {
    #[serde(default)]
    pub crawl: bool,
}

#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Model {
    pub repository: String,
    pub name: Option<String>,
}

const fn n_gpu_layers_default() -> u32 {
    1000
}

const fn n_ctx_default() -> u32 {
    1000
}

#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct LLaMACPP {
    // The model to use
    #[serde(flatten)]
    pub model: Model,
    #[serde(default = "n_gpu_layers_default")]
    pub n_gpu_layers: u32,
    #[serde(default = "n_ctx_default")]
    pub n_ctx: u32,
}

const fn api_max_requests_per_second_default() -> f32 {
    0.5
}

#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct OpenAI {
    // The auth token env var name
    pub auth_token_env_var_name: Option<String>,
    pub auth_token: Option<String>,
    // The completions endpoint
    pub completions_endpoint: Option<String>,
    // The chat endpoint
    pub chat_endpoint: Option<String>,
    // The maximum requests per second
    #[serde(default = "api_max_requests_per_second_default")]
    pub max_requests_per_second: f32,
    // The model name
    pub model: String,
}

#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Anthropic {
    // The auth token env var name
    pub auth_token_env_var_name: Option<String>,
    pub auth_token: Option<String>,
    // The completions endpoint
    pub completions_endpoint: Option<String>,
    // The chat endpoint
    pub chat_endpoint: Option<String>,
    // The maximum requests per second
    #[serde(default = "api_max_requests_per_second_default")]
    pub max_requests_per_second: f32,
    // The model name
    pub model: String,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Completion {
    // The model key to use
    pub model: String,
    // Args are deserialized by the backend using them
    #[serde(default)]
    pub parameters: Kwargs,
}

#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ValidConfig {
    pub memory: ValidMemoryBackend,
    pub models: HashMap<String, ValidModel>,
    pub completion: Option<Completion>,
}

#[derive(Clone, Debug, Deserialize, Default)]
pub struct ValidClientParams {
    #[serde(alias = "rootURI")]
    _root_uri: Option<String>,
    _workspace_folders: Option<Vec<String>>,
}

#[derive(Clone, Debug)]
pub struct Config {
    pub config: ValidConfig,
    _client_params: ValidClientParams,
}
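// An illustrative sketch, not part of the original file: `Completion.parameters`
// is an opaque `Kwargs` map (`HashMap<String, Value>`), so backend-specific keys
// such as `fim` or `max_new_tokens` survive deserialization as raw JSON values
// and are only interpreted later by the backend that uses them. The module and
// test names here are hypothetical.
#[cfg(test)]
mod kwargs_passthrough_sketch {
    use super::*;
    use serde_json::json;

    #[test]
    fn completion_parameters_pass_through() {
        let completion: Completion = serde_json::from_value(json!({
            "model": "model1",
            "parameters": {
                "max_new_tokens": 32,
                "fim": { "start": "", "middle": "", "end": "" }
            }
        }))
        .unwrap();
        assert_eq!(completion.model, "model1");
        // The keys survive untouched as raw `serde_json::Value`s
        assert!(completion.parameters.contains_key("max_new_tokens"));
        assert!(completion.parameters.contains_key("fim"));
    }
}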
impl Config {
    pub fn new(mut args: Value) -> Result<Self> {
        // Validate that the models specified are there so we can unwrap
        let configuration_args = args
            .as_object_mut()
            .context("Server configuration must be a JSON object")?
            .remove("initializationOptions");
        let valid_args = match configuration_args {
            Some(configuration_args) => serde_json::from_value(configuration_args)?,
            None => anyhow::bail!("lsp-ai does not currently provide a default configuration. Please pass a configuration. See https://github.com/SilasMarvin/lsp-ai for configuration options and examples"),
        };
        let client_params: ValidClientParams = serde_json::from_value(args)?;
        Ok(Self {
            config: valid_args,
            _client_params: client_params,
        })
    }

    ///////////////////////////////////////
    // Helpers for the backends ///////////
    ///////////////////////////////////////

    pub fn is_completions_enabled(&self) -> bool {
        self.config.completion.is_some()
    }

    pub fn get_completion_transformer_max_requests_per_second(&self) -> anyhow::Result<f32> {
        match &self
            .config
            .models
            .get(
                &self
                    .config
                    .completion
                    .as_ref()
                    .context("Completions is not enabled")?
                    .model,
            )
            .with_context(|| {
                format!(
                    "`{}` model not found in `models` config",
                    &self.config.completion.as_ref().unwrap().model
                )
            })? {
            #[cfg(feature = "llamacpp")]
            ValidModel::LLaMACPP(_) => Ok(1.),
            ValidModel::OpenAI(openai) => Ok(openai.max_requests_per_second),
            ValidModel::Anthropic(anthropic) => Ok(anthropic.max_requests_per_second),
        }
    }
}

// This makes testing much easier.
#[cfg(test)]
impl Config {
    pub fn default_with_file_store_without_models() -> Self {
        Self {
            config: ValidConfig {
                memory: ValidMemoryBackend::FileStore(FileStore { crawl: false }),
                models: HashMap::new(),
                completion: None,
            },
            _client_params: ValidClientParams {
                _root_uri: None,
                _workspace_folders: None,
            },
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use serde_json::json;

    #[test]
    #[cfg(feature = "llamacpp")]
    fn llama_cpp_config() {
        let args = json!({
            "initializationOptions": {
                "memory": {
                    "file_store": {}
                },
                "models": {
                    "model1": {
                        "type": "llamacpp",
                        "repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
                        "name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
                        "n_ctx": 2048,
                        "n_gpu_layers": 35
                    }
                },
                "completion": {
                    "model": "model1",
                    "parameters": {
                        "fim": {
                            "start": "",
                            "middle": "",
                            "end": ""
                        },
                        "max_context": 1024,
                        "max_new_tokens": 32,
                    }
                }
            }
        });
        Config::new(args).unwrap();
    }

    #[test]
    fn openai_config() {
        let args = json!({
            "initializationOptions": {
                "memory": {
                    "file_store": {}
                },
                "models": {
                    "model1": {
                        "type": "openai",
                        "completions_endpoint": "https://api.fireworks.ai/inference/v1/completions",
                        "model": "accounts/fireworks/models/llama-v2-34b-code",
                        "auth_token_env_var_name": "FIREWORKS_API_KEY",
                    },
                },
                "completion": {
                    "model": "model1",
                    "parameters": {
                        "messages": [
                            {
                                "role": "system",
                                "content": "Test",
                            },
                            {
                                "role": "user",
                                "content": "Test {CONTEXT} - {CODE}"
                            }
                        ],
                        "max_new_tokens": 32,
                    }
                }
            }
        });
        Config::new(args).unwrap();
    }

    #[test]
    fn anthropic_config() {
        let args = json!({
            "initializationOptions": {
                "memory": {
                    "file_store": {}
                },
                "models": {
                    "model1": {
                        "type": "anthropic",
                        "completions_endpoint": "https://api.anthropic.com/v1/messages",
                        "model": "claude-3-haiku-20240307",
                        "auth_token_env_var_name": "ANTHROPIC_API_KEY",
                    },
                },
                "completion": {
                    "model": "model1",
                    "parameters": {
                        "system": "Test",
                        "messages": [
                            {
                                "role": "user",
                                "content": "Test {CONTEXT} - {CODE}"
                            }
                        ],
                        "max_new_tokens": 32,
                    }
                }
            }
        });
        Config::new(args).unwrap();
    }
}
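// A minimal-configuration sketch, not part of the original file: `memory` and
// `models` are the only required fields of `ValidConfig`, so `models` may be an
// empty map and `completion` may be omitted entirely, in which case
// `is_completions_enabled` reports false. The module and test names here are
// hypothetical.
#[cfg(test)]
mod minimal_config_sketch {
    use super::*;
    use serde_json::json;

    #[test]
    fn file_store_without_models() {
        let args = json!({
            "initializationOptions": {
                "memory": { "file_store": {} },
                "models": {}
            }
        });
        let config = Config::new(args).unwrap();
        // With no `completion` block configured, completions are disabled
        assert!(!config.is_completions_enabled());
    }
}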