mirror of https://github.com/SilasMarvin/lsp-ai.git
Cleaned up llamacpp stuff
@@ -34,6 +34,7 @@ async-trait = "0.1.78"
 
 [features]
 default = []
+cublas = ["llama-cpp-2/cublas"]
 
 [dev-dependencies]
 assert_cmd = "2.0.14"
@@ -218,19 +218,18 @@ pub struct ValidConfiguration {
 #[derive(Clone, Debug, Deserialize, Default)]
 pub struct ValidClientParams {
     #[serde(alias = "rootURI")]
-    root_uri: Option<String>,
-    workspace_folders: Option<Vec<String>>,
+    _root_uri: Option<String>,
+    _workspace_folders: Option<Vec<String>>,
 }
 
 #[derive(Clone, Debug)]
 pub struct Configuration {
     pub config: ValidConfiguration,
-    client_params: ValidClientParams,
+    _client_params: ValidClientParams,
 }
 
 impl Configuration {
     pub fn new(mut args: Value) -> Result<Self> {
-        eprintln!("\n\n{}\n\n", args.to_string());
         let configuration_args = args
             .as_object_mut()
             .context("Server configuration must be a JSON object")?
@@ -242,7 +241,7 @@ impl Configuration {
         let client_params: ValidClientParams = serde_json::from_value(args)?;
         Ok(Self {
             config: valid_args,
-            client_params,
+            _client_params: client_params,
         })
     }
 
@@ -74,6 +74,7 @@ fn main() -> Result<()> {
 fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     // Build our configuration
     let configuration = Configuration::new(args)?;
+    eprintln!("GOT THE CONFIG: {:?}", configuration);
 
     // Wrap the connection for sharing between threads
     let connection = Arc::new(connection);
@@ -3,7 +3,7 @@ use indexmap::IndexSet;
 use lsp_types::TextDocumentPositionParams;
 use parking_lot::Mutex;
 use ropey::Rope;
-use std::{collections::HashMap, sync::Arc};
+use std::collections::HashMap;
 use tracing::instrument;
 
 use crate::{
@@ -14,29 +14,16 @@ use crate::{
 use super::{MemoryBackend, Prompt, PromptForType};
 
 pub struct FileStore {
-    crawl: bool,
+    _crawl: bool,
     configuration: Configuration,
     file_map: Mutex<HashMap<String, Rope>>,
     accessed_files: Mutex<IndexSet<String>>,
 }
 
-// TODO: Put some thought into the crawling here. Do we want to have a crawl option where it tries to crawl through all relevant
-// files and then when asked for context it loads them in by the most recently accessed? That seems kind of silly honestly, but I could see
-// how users who want to use models with massive context lengths would just want their entire project as context for generation tasks
-// I'm not sure yet, this is something I need to think through more
-
-// Ok here are some more ideas
-// We take a crawl arg which is a bool of true or false for file_store
-// If true we crawl until we get to the max_context_length and then we stop crawling
-// We keep track of the last opened / changed files, and prioritize those when building the context for our llms
-
-// For memory backends like PostgresML, they will need to take some kind of max_context_length to crawl or something.
-// In other words, there needs to be some specification for how much they should be crawling because the limiting happens in the vector_recall
 impl FileStore {
     pub fn new(file_store_config: configuration::FileStore, configuration: Configuration) -> Self {
-        // TODO: maybe crawl
         Self {
-            crawl: file_store_config.crawl,
+            _crawl: file_store_config.crawl,
             configuration,
             file_map: Mutex::new(HashMap::new()),
             accessed_files: Mutex::new(IndexSet::new()),
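Note on the design comments removed in the hunk above: they describe a budgeted crawl, that is, walking project files until a max_context_length style limit is reached and then preferring the most recently accessed files when building LLM context. Below is only a minimal sketch of that idea for illustration; the type and field names (CrawlSketch, max_context_chars) are assumptions and do not come from this commit.

    use indexmap::IndexSet;
    use ropey::Rope;
    use std::collections::HashMap;

    /// Hypothetical budgeted crawl: walk files, stop once the context budget is spent,
    /// and build prompts from the most recently accessed files first.
    struct CrawlSketch {
        max_context_chars: usize,
        file_map: HashMap<String, Rope>,
        accessed_files: IndexSet<String>,
    }

    impl CrawlSketch {
        /// Load files until the running character count exceeds the budget.
        fn crawl(&mut self, paths: impl Iterator<Item = String>) {
            let mut used = 0;
            for path in paths {
                if used >= self.max_context_chars {
                    break;
                }
                if let Ok(contents) = std::fs::read_to_string(&path) {
                    used += contents.chars().count();
                    self.file_map.insert(path.clone(), Rope::from_str(&contents));
                    self.accessed_files.insert(path);
                }
            }
        }

        /// Most recently accessed files first, truncated to the budget.
        fn build_context(&self) -> String {
            let mut context = String::new();
            for path in self.accessed_files.iter().rev() {
                if let Some(rope) = self.file_map.get(path) {
                    context.push_str(&rope.to_string());
                    if context.chars().count() >= self.max_context_chars {
                        break;
                    }
                }
            }
            context.chars().take(self.max_context_chars).collect()
        }
    }

Because IndexSet preserves insertion order, iterating it in reverse approximates most-recently-accessed-first, which is the prioritization the removed comments suggest.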
@@ -45,7 +32,7 @@ impl FileStore {
 
     pub fn new_without_crawl(configuration: Configuration) -> Self {
         Self {
-            crawl: false,
+            _crawl: false,
             configuration,
             file_map: Mutex::new(HashMap::new()),
             accessed_files: Mutex::new(IndexSet::new()),
@@ -135,8 +122,7 @@ impl FileStore {
                 .unwrap_or(false),
         };
 
-        // We only want to do FIM if the user has enabled it, the cursor is not at the end of the file,
-        // and the user has not enabled chat
+        // We only want to do FIM if the user has enabled it and the user has not enabled chat
         Ok(match (is_chat_enabled, self.configuration.get_fim()) {
             r @ (true, _) | r @ (false, Some(_))
                 if is_chat_enabled || rope.len_chars() != cursor_index =>
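The comment changed in the hunk above compresses the FIM rule: FIM is used only when it is configured, chat is not enabled, and the cursor is not at the end of the file. Read as a decision table, the guard shown there behaves roughly like the sketch below; the enum and function names are hypothetical and only restate the match, they are not code from the repository.

    /// Hypothetical prompt styles, only for illustrating the guard above.
    enum PromptStyle {
        Chat,       // chat messages are configured and enabled
        Fim,        // fill-in-the-middle: FIM configured and cursor not at EOF
        Completion, // plain left-to-right completion
    }

    /// Simplified restatement of the (is_chat_enabled, fim_config) match:
    /// chat wins when enabled; otherwise FIM applies only when it is configured
    /// and there is text after the cursor; otherwise fall back to completion.
    fn choose_prompt_style(
        is_chat_enabled: bool,
        fim_configured: bool,
        cursor_index: usize,
        len_chars: usize,
    ) -> PromptStyle {
        match (is_chat_enabled, fim_configured) {
            (true, _) => PromptStyle::Chat,
            (false, true) if len_chars != cursor_index => PromptStyle::Fim,
            _ => PromptStyle::Completion,
        }
    }

The fallback arm (plain completion) is outside the hunk shown here, so that last branch is an inference from the comment rather than quoted code.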
@@ -7,7 +7,7 @@ use anyhow::Context;
 use lsp_types::TextDocumentPositionParams;
 use pgml::{Collection, Pipeline};
 use serde_json::json;
-use tokio::{runtime::Runtime, time};
+use tokio::time;
 use tracing::instrument;
 
 use crate::{
@@ -22,7 +22,6 @@ pub struct PostgresML {
     file_store: FileStore,
     collection: Collection,
     pipeline: Pipeline,
-    runtime: Runtime,
     debounce_tx: Sender<String>,
     added_pipeline: bool,
 }
@@ -62,12 +61,11 @@ impl PostgresML {
                 .into(),
             ),
         )?;
-        // Create our own runtime
+        // Setup up a debouncer for changed text documents
         let runtime = tokio::runtime::Builder::new_multi_thread()
             .worker_threads(2)
             .enable_all()
             .build()?;
-        // Setup up a debouncer for changed text documents
         let mut task_collection = collection.clone();
         let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
         runtime.spawn(async move {
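The comment moved in the hunk above ("Setup up a debouncer for changed text documents"), together with the mpsc channel and the task spawned on the Tokio runtime, outlines a debounce pattern: change notifications are queued on debounce_tx and only flushed after a quiet period. The loop body is not part of this hunk, so the sketch below is a generic, hypothetical illustration of such a debouncer using only std::sync::mpsc; the 500 ms interval and the function name are assumptions.

    use std::sync::mpsc;
    use std::time::Duration;

    /// Hypothetical debounce loop: batch changed document URIs and only act once
    /// the channel has been quiet for DEBOUNCE, mirroring the channel-plus-task
    /// setup in the hunk above.
    fn debounce_loop(rx: mpsc::Receiver<String>) {
        const DEBOUNCE: Duration = Duration::from_millis(500);
        let mut pending: Vec<String> = Vec::new();
        loop {
            match rx.recv_timeout(DEBOUNCE) {
                // Keep batching while events continue to arrive.
                Ok(uri) => pending.push(uri),
                Err(mpsc::RecvTimeoutError::Timeout) => {
                    if !pending.is_empty() {
                        // Flush point: e.g. upsert the changed documents into the store.
                        pending.clear();
                    }
                }
                Err(mpsc::RecvTimeoutError::Disconnected) => break,
            }
        }
    }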
@@ -106,15 +104,11 @@ impl PostgresML {
                 }
             }
         });
-        // TODO: Maybe
-        // Need to crawl the root path and or workspace folders
-        // Or set some kind of did crawl for it
         Ok(Self {
             configuration,
             file_store,
             collection,
             pipeline,
-            runtime,
             debounce_tx,
             added_pipeline: false,
         })
@@ -41,17 +41,13 @@ impl Anthropic {
         messages: Vec<ChatMessage>,
         max_tokens: usize,
     ) -> anyhow::Result<String> {
-        eprintln!(
-            "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
-            messages
-        );
         let client = reqwest::Client::new();
         let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
             std::env::var(env_var_name)?
         } else if let Some(token) = &self.configuration.auth_token {
             token.to_string()
         } else {
-            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
+            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `transformer->anthropic` to use an Anthropic");
         };
         let res: AnthropicChatResponse = client
             .post(
@@ -111,7 +107,7 @@ impl TransformerBackend for Anthropic {
         let insert_text = match &self.configuration.chat.completion {
             Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
             None => {
-                anyhow::bail!("Please provide `anthropic->chat->completion` messages")
+                anyhow::bail!("Please set `transformer->anthropic->chat->completion` messages")
             }
         };
         Ok(DoCompletionResponse { insert_text })
@@ -124,7 +120,7 @@ impl TransformerBackend for Anthropic {
         let generated_text = match &self.configuration.chat.generation {
             Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
             None => {
-                anyhow::bail!("Please provide `anthropic->chat->generation` messages")
+                anyhow::bail!("Please set `transformer->anthropic->chat->generation` messages")
             }
         };
         Ok(DoGenerateResponse { generated_text })
@@ -1,6 +1,6 @@
 use anyhow::Context;
 use hf_hub::api::sync::ApiBuilder;
-use tracing::{debug, instrument};
+use tracing::instrument;
 
 use crate::{
     configuration::{self},
@@ -17,12 +17,12 @@ use model::Model;
 
 use super::TransformerBackend;
 
-pub struct LlamaCPP {
+pub struct LLaMACPP {
     model: Model,
     configuration: configuration::LLaMACPP,
 }
 
-impl LlamaCPP {
+impl LLaMACPP {
     #[instrument]
     pub fn new(configuration: configuration::LLaMACPP) -> anyhow::Result<Self> {
         let api = ApiBuilder::new().with_progress(true).build()?;
@@ -30,9 +30,9 @@ impl LlamaCPP {
             .model
             .name
             .as_ref()
-            .context("Model `name` is required when using GGUF models")?;
+            .context("Please set `transformer->llamacpp->name` to use LLaMA.cpp")?;
         let repo = api.model(configuration.model.repository.to_owned());
-        let model_path = repo.get(&name)?;
+        let model_path = repo.get(name)?;
         let model = Model::new(model_path, &configuration.kwargs)?;
         Ok(Self {
             model,
@@ -50,7 +50,7 @@ impl LlamaCPP {
         if let Some(chat_template) = &c.chat_template {
             let bos_token = self.model.get_bos_token()?;
             let eos_token = self.model.get_eos_token()?;
-            apply_chat_template(&chat_template, chat_messages, &bos_token, &eos_token)?
+            apply_chat_template(chat_template, chat_messages, &bos_token, &eos_token)?
         } else {
             self.model.apply_chat_template(chat_messages, None)?
         }
@@ -64,12 +64,10 @@ impl LlamaCPP {
 }
 
 #[async_trait::async_trait]
-impl TransformerBackend for LlamaCPP {
+impl TransformerBackend for LLaMACPP {
     #[instrument(skip(self))]
     async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
-        // let prompt = self.get_prompt_string(prompt)?;
-        let prompt = &prompt.code;
-        debug!("Prompt string for LLM: {}", prompt);
+        let prompt = self.get_prompt_string(prompt)?;
         let max_new_tokens = self.configuration.max_tokens.completion;
         self.model
             .complete(&prompt, max_new_tokens)
@@ -78,9 +76,7 @@ impl TransformerBackend for LlamaCPP {
 
     #[instrument(skip(self))]
     async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
-        // let prompt = self.get_prompt_string(prompt)?;
-        // debug!("Prompt string for LLM: {}", prompt);
-        let prompt = &prompt.code;
+        let prompt = self.get_prompt_string(prompt)?;
         let max_new_tokens = self.configuration.max_tokens.completion;
         self.model
             .complete(&prompt, max_new_tokens)
@@ -96,60 +92,55 @@ impl TransformerBackend for LlamaCPP {
     }
 }
 
-// #[cfg(test)]
-// mod tests {
-//     use super::*;
-//     use serde_json::json;
+#[cfg(test)]
+mod test {
+    use super::*;
+    use serde_json::json;
 
-// #[test]
-// fn test_gguf() {
-// let args = json!({
-// "initializationOptions": {
-// "memory": {
-// "file_store": {}
-// },
-// "model_gguf": {
-// "repository": "stabilityai/stable-code-3b",
-// "name": "stable-code-3b-Q5_K_M.gguf",
-// "max_new_tokens": {
-// "completion": 32,
-// "generation": 256,
-// },
-// // "fim": {
-// // "start": "",
-// // "middle": "",
-// // "end": ""
-// // },
-// "chat": {
-// "completion": [
-// {
-// "role": "system",
-// "message": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief.\n\n{context}",
-// },
-// {
-// "role": "user",
-// "message": "Complete the following code: \n\n{code}"
-// }
-// ],
-// "generation": [
-// {
-// "role": "system",
-// "message": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
-// },
-// {
-// "role": "user",
-// "message": "Complete the following code: \n\n{code}"
-// }
-// ]
-// },
-// "n_ctx": 2048,
-// "n_gpu_layers": 1000,
-// }
-// },
-// });
-// let configuration = Configuration::new(args).unwrap();
-// let _model = LlamaCPP::new(configuration).unwrap();
-// // let output = model.do_completion("def fibon").unwrap();
-// // println!("{}", output.insert_text);
-// }
-// }
+    #[tokio::test]
+    async fn llama_cpp_do_completion() -> anyhow::Result<()> {
+        let configuration: configuration::LLaMACPP = serde_json::from_value(json!({
+            "repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
+            "name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
+            "max_new_tokens": {
+                "completion": 32,
+                "generation": 256,
+            },
+            "fim": {
+                "start": "<fim_prefix>",
+                "middle": "<fim_suffix>",
+                "end": "<fim_middle>"
+            },
+            "chat": {
+                // "completion": [
+                //     {
+                //         "role": "system",
+                //         "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
+                //     },
+                //     {
+                //         "role": "user",
+                //         "content": "Complete the following code: \n\n{code}"
+                //     }
+                // ],
+                // "generation": [
+                //     {
+                //         "role": "system",
+                //         "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
+                //     },
+                //     {
+                //         "role": "user",
+                //         "content": "Complete the following code: \n\n{code}"
+                //     }
+                // ],
+                "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
+            },
+            "n_ctx": 2048,
+            "n_gpu_layers": 35,
+        }))?;
+        let llama_cpp = LLaMACPP::new(configuration).unwrap();
+        let prompt = Prompt::default_with_cursor();
+        let response = llama_cpp.do_completion(&prompt).await?;
+        assert!(!response.insert_text.is_empty());
+        Ok(())
+    }
+}
@@ -31,20 +31,23 @@ impl Model {
             .unwrap_or_else(|| Ok(1000))?;
 
         // Initialize the model_params
-        let model_params = {
-            #[cfg(feature = "cublas")]
-            if !params.disable_gpu {
-                LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers)
-            } else {
-                LlamaModelParams::default()
-            }
-            #[cfg(not(feature = "cublas"))]
-            LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers)
-        };
+        let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);
 
         // Load the model
+        eprintln!();
+        eprintln!();
+        eprintln!();
+        eprintln!();
+        eprintln!();
+        eprintln!();
         debug!("Loading model at path: {:?}", model_path);
         let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
+        eprintln!();
+        eprintln!("LOADED THE MODEL");
+        eprintln!();
+        eprintln!();
+        eprintln!();
+        eprintln!();
 
         // Get n_ctx if set in kwargs
         // As a default we set it to 2048
@@ -65,7 +68,7 @@ impl Model {
     pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
         // initialize the context
         let ctx_params = LlamaContextParams::default()
-            .with_n_ctx(Some(self.n_ctx.clone()))
+            .with_n_ctx(Some(self.n_ctx))
             .with_n_batch(self.n_ctx.get());
 
         let mut ctx = self
@@ -26,7 +26,7 @@ impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send + Sync> {
     fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
        match configuration.config.transformer {
            ValidTransformerBackend::LLaMACPP(model_gguf) => {
-                Ok(Box::new(llama_cpp::LlamaCPP::new(model_gguf)?))
+                Ok(Box::new(llama_cpp::LLaMACPP::new(model_gguf)?))
            }
            ValidTransformerBackend::OpenAI(openai_config) => {
                Ok(Box::new(openai::OpenAI::new(openai_config)))