use anyhow::Context;
use hf_hub::api::sync::Api;
use tracing::{debug, instrument};

use super::TransformerBackend;
use crate::{
    configuration::Configuration,
    memory_backends::Prompt,
    utils::format_chat_messages,
    worker::{
        DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
    },
};

mod model;
use model::Model;

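/// A transformer backend that runs GGUF models locally via llama.cpp, holding
/// the loaded model and the configuration used to build prompts.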
pub struct LlamaCPP {
    model: Model,
    configuration: Configuration,
}

impl LlamaCPP {
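    /// Resolves the configured GGUF model through the Hugging Face Hub API
    /// (reusing the local cache when available) and loads it with the
    /// configured model kwargs.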
    #[instrument]
    pub fn new(configuration: Configuration) -> anyhow::Result<Self> {
        let api = Api::new()?;
        let model = configuration.get_model()?;
        let name = model
            .name
            .as_ref()
            .context("Model `name` is required when using GGUF models")?;
        let repo = api.model(model.repository.to_owned());
        let model_path = repo.get(name)?;
        let model = Model::new(model_path, configuration.get_model_kwargs()?)?;
        Ok(Self {
            model,
            configuration,
        })
    }

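    /// Builds the prompt string for the model: if chat `completion` messages
    /// are configured, they are formatted and passed through the model's chat
    /// template; otherwise the raw code from the prompt is used as-is.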
    #[instrument(skip(self))]
    fn get_prompt_string(&self, prompt: &Prompt) -> anyhow::Result<String> {
        // Check not only that the `chat` key is set, but also that its `completion` sub-key is set
        Ok(match self.configuration.get_chat()? {
            Some(c) => {
                if let Some(completion_messages) = &c.completion {
                    let chat_messages = format_chat_messages(completion_messages, prompt);
                    self.model
                        .apply_chat_template(chat_messages, c.chat_template.to_owned())?
                } else {
                    prompt.code.to_owned()
                }
            }
            None => prompt.code.to_owned(),
        })
    }
}

impl TransformerBackend for LlamaCPP {
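    /// Produces an in-editor completion, bounded by the `completion` token budget.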
    #[instrument(skip(self))]
    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
        let prompt = self.get_prompt_string(prompt)?;
        debug!("Prompt string for LLM: {}", prompt);
        let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
        self.model
            .complete(&prompt, max_new_tokens)
            .map(|insert_text| DoCompletionResponse { insert_text })
    }

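    /// Produces a longer generation, bounded by the `generation` token budget.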
    #[instrument(skip(self))]
    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
        let prompt = self.get_prompt_string(prompt)?;
        debug!("Prompt string for LLM: {}", prompt);
        let max_new_tokens = self.configuration.get_max_new_tokens()?.generation;
        self.model
            .complete(&prompt, max_new_tokens)
            .map(|generated_text| DoGenerateResponse { generated_text })
    }

    #[instrument(skip(self))]
    fn do_generate_stream(
        &self,
        _request: &GenerateStreamRequest,
    ) -> anyhow::Result<DoGenerateStreamResponse> {
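        // Streaming generation is not yet supported by this backend.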
        unimplemented!()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

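    // Smoke test: builds a configuration pointing at a GGUF model on the
    // Hugging Face Hub and checks that `LlamaCPP::new` can download and load
    // it. Needs network access and enough memory for the model, so it is slow.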
    #[test]
    fn test_gguf() {
        let args = json!({
            "initializationOptions": {
                "memory": {
                    "file_store": {}
                },
                "macos": {
                    "model_gguf": {
                        "repository": "stabilityai/stable-code-3b",
                        "name": "stable-code-3b-Q5_K_M.gguf",
                        "max_new_tokens": {
                            "completion": 32,
                            "generation": 256
                        },
                        // "fim": {
                        //     "start": "",
                        //     "middle": "",
                        //     "end": ""
                        // },
                        "chat": {
                            "completion": [
                                {
                                    "role": "system",
                                    "message": "You are a code completion chatbot. Use the following context to complete the next segment of code. Keep your response brief.\n\n{context}"
                                },
                                {
                                    "role": "user",
                                    "message": "Complete the following code: \n\n{code}"
                                }
                            ],
                            "generation": [
                                {
                                    "role": "system",
                                    "message": "You are a code completion chatbot. Use the following context to complete the next segment of code. \n\n{context}"
                                },
                                {
                                    "role": "user",
                                    "message": "Complete the following code: \n\n{code}"
                                }
                            ]
                        },
                        "n_ctx": 2048,
                        "n_gpu_layers": 1000
                    }
                }
            }
        });
        let configuration = Configuration::new(args).unwrap();
        let _model = LlamaCPP::new(configuration).unwrap();
        // let output = model.do_completion("def fibon").unwrap();
        // println!("{}", output.insert_text);
    }
}