diff --git a/src/config.rs b/src/config.rs
index 89f3c68..fcb5def 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -81,13 +81,6 @@ pub struct FileStore {
     pub crawl: bool,
 }
 
-#[derive(Clone, Debug, Deserialize)]
-#[serde(deny_unknown_fields)]
-pub struct Model {
-    pub repository: String,
-    pub name: Option<String>,
-}
-
 const fn n_gpu_layers_default() -> u32 {
     1000
 }
@@ -106,6 +99,7 @@ pub struct MistralFIM {
     pub fim_endpoint: Option<String>,
     // The model name
     pub model: String,
+    // The maximum requests per second
     #[serde(default = "max_requests_per_second_default")]
     pub max_requests_per_second: f32,
 }
@@ -113,13 +107,17 @@ pub struct MistralFIM {
 #[derive(Clone, Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct LLaMACPP {
-    // The model to use
-    #[serde(flatten)]
-    pub model: Model,
+    // Which model to use
+    pub repository: Option<String>,
+    pub name: Option<String>,
+    pub file_path: Option<String>,
+    // The layers to put on the GPU
     #[serde(default = "n_gpu_layers_default")]
     pub n_gpu_layers: u32,
+    // The context size
     #[serde(default = "n_ctx_default")]
     pub n_ctx: u32,
+    // The maximum requests per second
     #[serde(default = "max_requests_per_second_default")]
     pub max_requests_per_second: f32,
 }
@@ -129,6 +127,7 @@ pub struct LLaMACPP {
 pub struct OpenAI {
     // The auth token env var name
     pub auth_token_env_var_name: Option<String>,
+    // The auth token
     pub auth_token: Option<String>,
     // The completions endpoint
     pub completions_endpoint: Option<String>,
diff --git a/src/transformer_backends/llama_cpp/mod.rs b/src/transformer_backends/llama_cpp/mod.rs
index 3e5e758..a1e1ac1 100644
--- a/src/transformer_backends/llama_cpp/mod.rs
+++ b/src/transformer_backends/llama_cpp/mod.rs
@@ -9,7 +9,6 @@ use crate::{
     },
     utils::format_chat_messages,
 };
-use anyhow::Context;
 use hf_hub::api::sync::ApiBuilder;
 use serde::Deserialize;
 use serde_json::Value;
@@ -41,15 +40,22 @@ pub struct LLaMACPP {
 impl LLaMACPP {
     #[instrument]
     pub fn new(configuration: config::LLaMACPP) -> anyhow::Result<Self> {
-        let api = ApiBuilder::new().with_progress(true).build()?;
-        let name = configuration
-            .model
-            .name
-            .as_ref()
-            .context("Please set `name` to use LLaMA.cpp")?;
-        error!("Loading in: {} - {}\nIf this model has not been loaded before it may take a few minutes to download it. Please hangtight.", configuration.model.repository, name);
-        let repo = api.model(configuration.model.repository.to_owned());
-        let model_path = repo.get(name)?;
+        let model_path = match (
+            &configuration.file_path,
+            &configuration.repository,
+            &configuration.name,
+        ) {
+            (Some(file_path), _, _) => std::path::PathBuf::from(file_path),
+            (_, Some(repository), Some(name)) => {
+                let api = ApiBuilder::new().with_progress(true).build()?;
+                error!("Loading in: {} - {}\nIf this model has not been loaded before it may take a few minutes to download it. Please hangtight.", repository, name);
+                let repo = api.model(repository.clone());
+                repo.get(&name)?
+            }
+            _ => anyhow::bail!(
+                "To use llama.cpp provide either `file_path` or `repository` and `name`"
+            ),
+        };
         let model = Model::new(model_path, &configuration)?;
         Ok(Self { model })
     }
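For reference, a minimal sketch (not part of the patch) of how the reworked `LLaMACPP` config section could be deserialized after this change. The struct is trimmed to just the new model-selection fields, and the JSON values (file path, repository, file name) are made-up examples:

```rust
use serde::Deserialize;

// Trimmed-down copy of the reworked config struct, keeping only the
// new model-selection fields for illustration.
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct LLaMACPP {
    repository: Option<String>,
    name: Option<String>,
    file_path: Option<String>,
}

fn main() -> anyhow::Result<()> {
    // Point llama.cpp at a local GGUF file directly...
    let local: LLaMACPP = serde_json::from_str(r#"{ "file_path": "/models/model.gguf" }"#)?;
    // ...or name a Hugging Face repository and file to download.
    let remote: LLaMACPP = serde_json::from_str(
        r#"{ "repository": "some-org/some-model-GGUF", "name": "model.Q4_K_M.gguf" }"#,
    )?;
    println!("{local:?}\n{remote:?}");
    Ok(())
}
```

Supplying `file_path` alone, or `repository` plus `name`, satisfies the new `match` in `LLaMACPP::new`; anything else hits the `anyhow::bail!` branch.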