Added config option for llama_cpp models

Silas Marvin
2024-06-09 08:05:36 -07:00
parent f84b5fe2be
commit ef86f20667
2 changed files with 25 additions and 20 deletions

View File

@@ -81,13 +81,6 @@ pub struct FileStore {
pub crawl: bool,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Model {
pub repository: String,
pub name: Option<String>,
}
const fn n_gpu_layers_default() -> u32 {
1000
}
@@ -106,6 +99,7 @@ pub struct MistralFIM {
pub fim_endpoint: Option<String>,
// The model name
pub model: String,
// The maximum requests per second
#[serde(default = "max_requests_per_second_default")]
pub max_requests_per_second: f32,
}
@@ -113,13 +107,17 @@ pub struct MistralFIM {
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct LLaMACPP {
// The model to use
#[serde(flatten)]
pub model: Model,
// Which model to use
pub repository: Option<String>,
pub name: Option<String>,
pub file_path: Option<String>,
// The layers to put on the GPU
#[serde(default = "n_gpu_layers_default")]
pub n_gpu_layers: u32,
// The context size
#[serde(default = "n_ctx_default")]
pub n_ctx: u32,
// The maximum requests per second
#[serde(default = "max_requests_per_second_default")]
pub max_requests_per_second: f32,
}
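
For reference, a minimal sketch of what the reworked struct accepts once deserialized. The struct fields and serde defaults are copied from the hunk above; the default values for n_ctx and max_requests_per_second, the use of serde_json, and the repository/file names in the JSON strings are assumptions for illustration, not taken from the commit.

use serde::Deserialize;

const fn n_gpu_layers_default() -> u32 {
    1000
}
// Assumed defaults: the commit does not show n_ctx_default or max_requests_per_second_default.
const fn n_ctx_default() -> u32 {
    2048
}
fn max_requests_per_second_default() -> f32 {
    1.0
}

// Mirror of the LLaMACPP config struct from the hunk above, repeated so the sketch compiles on its own.
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct LLaMACPP {
    pub repository: Option<String>,
    pub name: Option<String>,
    pub file_path: Option<String>,
    #[serde(default = "n_gpu_layers_default")]
    pub n_gpu_layers: u32,
    #[serde(default = "n_ctx_default")]
    pub n_ctx: u32,
    #[serde(default = "max_requests_per_second_default")]
    pub max_requests_per_second: f32,
}

fn main() -> anyhow::Result<()> {
    // Option 1: fetch a GGUF file from a Hugging Face repository (placeholder values).
    let from_hub: LLaMACPP =
        serde_json::from_str(r#"{"repository": "owner/some-model-GGUF", "name": "model.Q4_K_M.gguf"}"#)?;
    // Option 2: the new config option — point directly at a model already on disk (placeholder path).
    let from_disk: LLaMACPP =
        serde_json::from_str(r#"{"file_path": "/path/to/model.gguf", "n_gpu_layers": 35}"#)?;
    println!("{from_hub:?}\n{from_disk:?}");
    Ok(())
}
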
@@ -129,6 +127,7 @@ pub struct LLaMACPP {
pub struct OpenAI {
// The auth token env var name
pub auth_token_env_var_name: Option<String>,
// The auth token
pub auth_token: Option<String>,
// The completions endpoint
pub completions_endpoint: Option<String>,

View File

@@ -9,7 +9,6 @@ use crate::{
},
utils::format_chat_messages,
};
use anyhow::Context;
use hf_hub::api::sync::ApiBuilder;
use serde::Deserialize;
use serde_json::Value;
@@ -41,15 +40,22 @@ pub struct LLaMACPP {
impl LLaMACPP {
#[instrument]
pub fn new(configuration: config::LLaMACPP) -> anyhow::Result<Self> {
let api = ApiBuilder::new().with_progress(true).build()?;
let name = configuration
.model
.name
.as_ref()
.context("Please set `name` to use LLaMA.cpp")?;
error!("Loading in: {} - {}\nIf this model has not been loaded before it may take a few minutes to download it. Please hangtight.", configuration.model.repository, name);
let repo = api.model(configuration.model.repository.to_owned());
let model_path = repo.get(name)?;
let model_path = match (
&configuration.file_path,
&configuration.repository,
&configuration.name,
) {
(Some(file_path), _, _) => std::path::PathBuf::from(file_path),
(_, Some(repository), Some(name)) => {
let api = ApiBuilder::new().with_progress(true).build()?;
error!("Loading in: {} - {}\nIf this model has not been loaded before it may take a few minutes to download it. Please hangtight.", repository, name);
let repo = api.model(repository.clone());
repo.get(&name)?
}
_ => anyhow::bail!(
"To use llama.cpp provide either `file_path` or `repository` and `name`"
),
};
let model = Model::new(model_path, &configuration)?;
Ok(Self { model })
}
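
Worth noting from the match above: `file_path` takes precedence even when `repository` and `name` are also set, and a config that supplies neither form is rejected with the bail message. Below is a hypothetical standalone helper showing the same resolution rule; the function name and the test values are illustrative and not part of the project.

use hf_hub::api::sync::ApiBuilder;
use std::path::PathBuf;

// Hypothetical extraction of the three-way match in LLaMACPP::new, so the precedence is easy to see and test.
fn resolve_model_path(
    file_path: &Option<String>,
    repository: &Option<String>,
    name: &Option<String>,
) -> anyhow::Result<PathBuf> {
    match (file_path, repository, name) {
        // An explicit local path wins over everything else.
        (Some(file_path), _, _) => Ok(PathBuf::from(file_path)),
        // Otherwise both the repository and the file name inside it are required;
        // hf_hub caches downloads, so repeated calls are cheap after the first one.
        (_, Some(repository), Some(name)) => {
            let api = ApiBuilder::new().with_progress(true).build()?;
            Ok(api.model(repository.clone()).get(name)?)
        }
        _ => anyhow::bail!(
            "To use llama.cpp provide either `file_path` or `repository` and `name`"
        ),
    }
}

fn main() {
    // file_path short-circuits, so no network access happens here (placeholder paths).
    let resolved = resolve_model_path(
        &Some("/models/local.gguf".to_string()),
        &Some("owner/some-model-GGUF".to_string()),
        &None,
    );
    assert_eq!(resolved.unwrap(), PathBuf::from("/models/local.gguf"));
}
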