Cleaned a lot of stuff up, added tracing and chat formatting

Silas Marvin
2024-03-03 14:40:42 -08:00
parent 28b7b1b74e
commit 6627da705e
12 changed files with 353 additions and 225 deletions


@@ -4,13 +4,14 @@ use llama_cpp_2::{
     ggml_time_us,
     llama_backend::LlamaBackend,
     llama_batch::LlamaBatch,
-    model::{params::LlamaModelParams, AddBos, LlamaModel},
+    model::{params::LlamaModelParams, AddBos, LlamaChatMessage, LlamaModel},
     token::data_array::LlamaTokenDataArray,
 };
 use once_cell::sync::Lazy;
 use std::{num::NonZeroU32, path::PathBuf, time::Duration};
+use tracing::{debug, info, instrument};
-use crate::configuration::Kwargs;
+use crate::configuration::{ChatMessage, Kwargs};
 static BACKEND: Lazy<LlamaBackend> = Lazy::new(|| LlamaBackend::init().unwrap());
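This hunk only brings the tracing macros into scope; the subscriber that actually renders the new debug!/info! output lives in one of the other changed files and is not shown here. A rough sketch of the kind of setup those macros assume (the stderr writer and DEBUG level are assumptions, not taken from this diff):

use tracing::Level;

// Minimal tracing-subscriber setup so that debug!/info! and #[instrument] spans
// are actually emitted; requires the tracing-subscriber crate.
fn init_logging() {
    tracing_subscriber::fmt()
        .with_max_level(Level::DEBUG)
        // Log to stderr so normal output on stdout is not mixed with log lines.
        .with_writer(std::io::stderr)
        .init();
}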
@@ -20,6 +21,7 @@ pub struct Model {
 }
 impl Model {
+    #[instrument]
     pub fn new(model_path: PathBuf, kwargs: &Kwargs) -> anyhow::Result<Self> {
         // Get n_gpu_layers if set in kwargs
         // As a default we set it to 1000, which should put all layers on the GPU
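The body the two comments above describe sits outside this hunk. A minimal sketch of the pattern, treating Kwargs as a plain map of JSON values (its real definition in crate::configuration is not part of this diff):

use std::collections::HashMap;

// Stand-in for crate::configuration::Kwargs; the real type is not shown in this commit.
type Kwargs = HashMap<String, serde_json::Value>;

fn n_gpu_layers(kwargs: &Kwargs) -> u32 {
    // Use the kwargs value when present, otherwise default to 1000,
    // which is large enough to offload every layer to the GPU.
    kwargs
        .get("n_gpu_layers")
        .and_then(|v| v.as_u64())
        .map(|v| v as u32)
        .unwrap_or(1000)
}

The resulting value would then be applied to LlamaModelParams before LlamaModel::load_from_file runs below.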
@@ -41,9 +43,8 @@ impl Model {
         };
         // Load the model
-        eprintln!("SETTING MODEL AT PATH: {:?}", model_path);
+        debug!("Loading model at path: {:?}", model_path);
         let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
-        eprintln!("\nMODEL SET\n");
         // Get n_ctx if set in kwargs
         // As a default we set it to 2048
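The n_ctx lines referenced by the trailing comments are likewise elided; presumably they read the value from kwargs and fall back to 2048, stored as the NonZeroU32 that with_n_ctx receives in complete below. A sketch under that assumption, using the same map-of-JSON stand-in for Kwargs:

use std::collections::HashMap;
use std::num::NonZeroU32;

fn n_ctx(kwargs: &HashMap<String, serde_json::Value>) -> NonZeroU32 {
    // Fall back to a 2048-token context window when kwargs does not override it.
    kwargs
        .get("n_ctx")
        .and_then(|v| v.as_u64())
        .and_then(|v| NonZeroU32::new(v as u32))
        .unwrap_or_else(|| NonZeroU32::new(2048).unwrap())
}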
@@ -60,6 +61,7 @@ impl Model {
         Ok(Model { model, n_ctx })
     }
+    #[instrument(skip(self))]
     pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
         // initialize the context
         let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone()));
@@ -77,9 +79,7 @@ impl Model {
         let n_cxt = ctx.n_ctx() as usize;
         let n_kv_req = tokens_list.len() + max_new_tokens;
-        eprintln!(
-            "n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}"
-        );
+        info!("n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}");
         // make sure the KV cache is big enough to hold all the prompt and generated tokens
         if n_kv_req > n_cxt {
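The body of this if falls outside the hunk; presumably it refuses to decode rather than overflow the KV cache. A sketch of the kind of guard the comment describes, using anyhow for the error path as the rest of the file does (the function name and message are illustrative only):

fn ensure_kv_cache_fits(prompt_tokens: usize, max_new_tokens: usize, n_ctx: usize) -> anyhow::Result<()> {
    // Every prompt token plus every token we intend to generate needs a slot in the KV cache.
    let n_kv_req = prompt_tokens + max_new_tokens;
    if n_kv_req > n_ctx {
        anyhow::bail!("required KV cache size {n_kv_req} exceeds the context window {n_ctx}");
    }
    Ok(())
}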
@@ -132,14 +132,29 @@ impl Model {
         let t_main_end = ggml_time_us();
         let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
-        eprintln!(
+        info!(
             "decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
             n_decode,
             duration.as_secs_f32(),
             n_decode as f32 / duration.as_secs_f32()
         );
-        eprintln!("{}", ctx.timings());
+        info!("{}", ctx.timings());
         Ok(output.join(""))
     }
+    #[instrument(skip(self))]
+    pub fn apply_chat_template(
+        &self,
+        messages: Vec<ChatMessage>,
+        template: Option<String>,
+    ) -> anyhow::Result<String> {
+        let llama_chat_messages = messages
+            .into_iter()
+            .map(|c| LlamaChatMessage::new(c.role, c.content))
+            .collect::<Result<Vec<LlamaChatMessage>, _>>()?;
+        Ok(self
+            .model
+            .apply_chat_template(template, llama_chat_messages, true)?)
+    }
 }
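The new apply_chat_template converts the configuration's ChatMessage values into LlamaChatMessages and has the model render them, with None meaning the chat template embedded in the GGUF file and Some(t) overriding it; the trailing true corresponds to llama.cpp's flag for opening the assistant turn so generation continues from there. A rough usage sketch; the struct literal assumes ChatMessage has plain String role/content fields, which the c.role/c.content mapping suggests but this diff does not confirm:

fn chat_complete(model: &Model, user_prompt: &str) -> anyhow::Result<String> {
    // Hypothetical call site: render a chat-formatted prompt, then complete it.
    let messages = vec![
        ChatMessage {
            role: "system".to_string(),
            content: "You are a helpful coding assistant.".to_string(),
        },
        ChatMessage {
            role: "user".to_string(),
            content: user_prompt.to_string(),
        },
    ];
    // None lets the model's own chat template format the conversation.
    let prompt = model.apply_chat_template(messages, None)?;
    model.complete(&prompt, 256)
}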