use anyhow::Context;
use llama_cpp_2::{
    context::params::LlamaContextParams,
    ggml_time_us,
    llama_backend::LlamaBackend,
    llama_batch::LlamaBatch,
    model::{params::LlamaModelParams, AddBos, LlamaChatMessage, LlamaModel},
    token::data_array::LlamaTokenDataArray,
};
use once_cell::sync::Lazy;
use std::{num::NonZeroU32, path::PathBuf, time::Duration};
use tracing::{debug, info, instrument};

use crate::configuration::{ChatMessage, Kwargs};

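// llama.cpp's global backend is initialized lazily, exactly once, and shared by every
// `Model` created in this process.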
static BACKEND: Lazy<LlamaBackend> = Lazy::new(|| LlamaBackend::init().unwrap());

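/// A model loaded through llama.cpp together with the context size (`n_ctx`) used when
/// running completions against it.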
pub struct Model {
    model: LlamaModel,
    n_ctx: NonZeroU32,
}

impl Model {
    #[instrument]
    pub fn new(model_path: PathBuf, kwargs: &Kwargs) -> anyhow::Result<Self> {
        // Get n_gpu_layers if set in kwargs
        // As a default we set it to 1000, which should put all layers on the GPU
        let n_gpu_layers = kwargs
            .get("n_gpu_layers")
            .map(|u| anyhow::Ok(u.as_u64().context("n_gpu_layers must be a number")? as u32))
            .unwrap_or_else(|| Ok(1000))?;

        // Initialize the model_params
        let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);

        // Load the model
        debug!("Loading model at path: {:?}", model_path);
        let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
        debug!("Model loaded");

        // Get n_ctx if set in kwargs
        // As a default we set it to 2048
        let n_ctx = kwargs
            .get("n_ctx")
            .map(|u| {
                anyhow::Ok(NonZeroU32::new(
                    u.as_u64().context("n_ctx must be a number")? as u32,
                ))
            })
            .unwrap_or_else(|| Ok(NonZeroU32::new(2048)))?
            .context("n_ctx must not be zero")?;

        Ok(Model { model, n_ctx })
    }

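    /// Greedily generates a completion for `prompt`, stopping at the model's end-of-stream
    /// token or after roughly `max_new_tokens` generated tokens, whichever comes first.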
    #[instrument(skip(self))]
    pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
        // initialize the context
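        // `n_batch` is set to the full context size so the whole prompt can be decoded in a
        // single call further down.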
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(Some(self.n_ctx))
            .with_n_batch(self.n_ctx.get());

        let mut ctx = self
            .model
            .new_context(&BACKEND, ctx_params)
            .with_context(|| "unable to create the llama_context")?;

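        // Tokenize the prompt, always prepending the model's beginning-of-stream token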
        let tokens_list = self
            .model
            .str_to_token(prompt, AddBos::Always)
            .with_context(|| format!("failed to tokenize {}", prompt))?;

        let n_ctx = ctx.n_ctx() as usize;
        let n_kv_req = tokens_list.len() + max_new_tokens;

        info!("n_len / max_new_tokens = {max_new_tokens}, n_ctx = {n_ctx}, n_kv_req = {n_kv_req}");

        // make sure the KV cache is big enough to hold all the prompt and generated tokens
        if n_kv_req > n_ctx {
            anyhow::bail!(
                "n_kv_req > n_ctx, the required kv cache size is not big enough
either reduce max_new_tokens or increase n_ctx"
            )
        }

        let mut batch = LlamaBatch::new(n_ctx, 1);

        let last_index: i32 = (tokens_list.len() - 1) as i32;
        for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
            // llama_decode will output logits only for the last token of the prompt
            let is_last = i == last_index;
            batch.add(token, i, &[0], is_last)?;
        }

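        // Run the whole prompt through the model in a single decode pass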
        ctx.decode(&mut batch)
            .with_context(|| "llama_decode() failed")?;

        // main loop
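        // Greedy decode loop: sample the highest-probability next token from the logits of
        // the last decoded position, collect it into `output`, then feed it back to the
        // model as a single-token batch.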
        let n_start = batch.n_tokens();
        let mut output: Vec<String> = vec![];
        let mut n_cur = n_start;
        let mut n_decode = 0;
        let t_main_start = ggml_time_us();
        while (n_cur as usize) <= (n_start as usize + max_new_tokens) {
            // sample the next token
            {
                let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
                let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);

                // sample the most likely token
                let new_token_id = ctx.sample_token_greedy(candidates_p);

                // is it an end of stream?
                if new_token_id == self.model.token_eos() {
                    break;
                }

                output.push(self.model.token_to_str(new_token_id)?);
                batch.clear();
                batch.add(new_token_id, n_cur, &[0], true)?;
            }
            n_cur += 1;
            ctx.decode(&mut batch).with_context(|| "failed to eval")?;
            n_decode += 1;
        }

        let t_main_end = ggml_time_us();
        let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
        info!(
            "decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
            n_decode,
            duration.as_secs_f32(),
            n_decode as f32 / duration.as_secs_f32()
        );
        info!("{}", ctx.timings());

        Ok(output.join(""))
    }

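    /// Renders `messages` into a single prompt string using the model's built-in chat
    /// template, or `template` if a custom one is supplied.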
    #[instrument(skip(self))]
    pub fn apply_chat_template(
        &self,
        messages: Vec<ChatMessage>,
        template: Option<String>,
    ) -> anyhow::Result<String> {
        let llama_chat_messages = messages
            .into_iter()
            .map(|c| LlamaChatMessage::new(c.role, c.content))
            .collect::<Result<Vec<LlamaChatMessage>, _>>()?;
        Ok(self
            .model
            .apply_chat_template(template, llama_chat_messages, true)?)
    }

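    /// Returns the model's end-of-stream token rendered as a string.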
    #[instrument(skip(self))]
    pub fn get_eos_token(&self) -> anyhow::Result<String> {
        let token = self.model.token_eos();
        Ok(self.model.token_to_str(token)?)
    }

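    /// Returns the model's beginning-of-stream token rendered as a string.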
    #[instrument(skip(self))]
    pub fn get_bos_token(&self) -> anyhow::Result<String> {
        let token = self.model.token_bos();
        Ok(self.model.token_to_str(token)?)
    }
}