From 28b7b1b74ec5cdebda583fb33768c41c5076dbb0 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 20:03:29 -0800 Subject: [PATCH] Start working on the chat feature --- Cargo.lock | 25 +++++++++++-- Cargo.toml | 9 +++-- src/configuration.rs | 36 ++++++++++++------ src/main.rs | 3 +- src/memory_backends/file_store.rs | 24 +++++++----- src/memory_backends/mod.rs | 15 +++++++- src/template.rs | 39 ++++++++++++++++++++ src/tokenizer.rs | 7 ++++ src/transformer_backends/llama_cpp/mod.rs | 45 +++++++++++++++-------- src/transformer_backends/mod.rs | 5 ++- src/utils.rs | 19 ++++++++++ src/worker.rs | 8 ++-- 12 files changed, 182 insertions(+), 53 deletions(-) create mode 100644 src/template.rs create mode 100644 src/tokenizer.rs diff --git a/Cargo.lock b/Cargo.lock index 69ad672..1fbca2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -141,6 +141,9 @@ name = "cc" version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730" +dependencies = [ + "libc", +] [[package]] name = "cexpr" @@ -619,8 +622,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "llama-cpp-2" -version = "0.1.25" -source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-8-metal-on-mac#8c61f584e7aa200581b711147e685821190aa025" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8a5342c3eb45011e7e3646e22c5b8fcd3f25e049f0eb9618048e40b0027a59c" dependencies = [ "llama-cpp-sys-2", "thiserror", @@ -629,8 +633,9 @@ dependencies = [ [[package]] name = "llama-cpp-sys-2" -version = "0.1.25" -source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-8-metal-on-mac#8c61f584e7aa200581b711147e685821190aa025" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1813a55afed6298991bcaaee040b49a83b473b3571ce37b4bbaa4b294ebcc36" dependencies = [ "bindgen", "cc", @@ -662,6 +667,7 @@ dependencies = [ "llama-cpp-2", "lsp-server", "lsp-types", + "minijinja", "once_cell", "parking_lot", "rand", @@ -674,6 +680,8 @@ dependencies = [ [[package]] name = "lsp-server" version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248f65b78f6db5d8e1b1604b4098a28b43d21a8eb1deeca22b1c421b276c7095" dependencies = [ "crossbeam-channel", "log", @@ -716,6 +724,15 @@ version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +[[package]] +name = "minijinja" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fe0ff215195a22884d867b547c70a0c4815cbbcc70991f281dca604b20d10ce" +dependencies = [ + "serde", +] + [[package]] name = "minimal-lexical" version = "0.2.1" diff --git a/Cargo.toml b/Cargo.toml index 6bd0bba..47057a8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,8 @@ edition = "2021" [dependencies] anyhow = "1.0.75" -# lsp-server = "0.7.4" -lsp-server = { path = "../rust-analyzer/lib/lsp-server" } +lsp-server = "0.7.4" +# lsp-server = { path = "../rust-analyzer/lib/lsp-server" } lsp-types = "0.94.1" ropey = "1.6.1" serde = "1.0.190" @@ -19,8 +19,9 @@ tokenizers = "0.14.1" parking_lot = "0.12.1" once_cell = "1.19.0" directories = "5.0.1" -# llama-cpp-2 = "0.1.27" -llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" } +llama-cpp-2 = "0.1.31" +# llama-cpp-2 
= { path = "../llama-cpp-rs/llama-cpp-2" }
+minijinja = "1.0.12"
 
 [features]
 default = []
diff --git a/src/configuration.rs b/src/configuration.rs
index 08a26f0..d1379fb 100644
--- a/src/configuration.rs
+++ b/src/configuration.rs
@@ -3,6 +3,8 @@ use serde::Deserialize;
 use serde_json::{json, Value};
 use std::collections::HashMap;
 
+use crate::memory_backends::Prompt;
+
 #[cfg(target_os = "macos")]
 const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
 
@@ -21,6 +23,20 @@ pub enum ValidTransformerBackend {
     PostgresML,
 }
 
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatMessage {
+    pub role: String,
+    pub message: String,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct Chat {
+    pub completion: Option<Vec<ChatMessage>>,
+    pub generation: Option<Vec<ChatMessage>>,
+    pub chat_template: Option<String>,
+    pub chat_format: Option<String>,
+}
+
 #[derive(Clone, Deserialize)]
 pub struct FIM {
     pub start: String,
@@ -56,18 +72,6 @@ impl Default for ValidMemoryConfiguration {
     }
 }
 
-#[derive(Clone, Deserialize)]
-struct ChatMessages {
-    role: String,
-    message: String,
-}
-
-#[derive(Clone, Deserialize)]
-struct Chat {
-    completion: Option<Vec<ChatMessages>>,
-    generation: Option<Vec<ChatMessages>>,
-}
-
 #[derive(Clone, Deserialize)]
 pub struct Model {
     pub repository: String,
@@ -230,6 +234,14 @@ impl Configuration {
             panic!("We currently only support gguf models using llama cpp")
         }
     }
+
+    pub fn get_chat(&self) -> Option<&Chat> {
+        if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
+            model_gguf.chat.as_ref()
+        } else {
+            panic!("We currently only support gguf models using llama cpp")
+        }
+    }
 }
 
 #[cfg(test)]
diff --git a/src/main.rs b/src/main.rs
index 2322463..cf610ed 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -11,6 +11,8 @@ use std::{sync::Arc, thread};
 mod configuration;
 mod custom_requests;
 mod memory_backends;
+mod template;
+mod tokenizer;
 mod transformer_backends;
 mod utils;
 mod worker;
@@ -80,7 +82,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
         let thread_memory_backend = memory_backend.clone();
         let thread_last_worker_request = last_worker_request.clone();
         let thread_connection = connection.clone();
-        // TODO: Pass some backend into here
         thread::spawn(move || {
             Worker::new(
                 transformer_backend,
diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs
index 746b1fd..2f3232d 100644
--- a/src/memory_backends/file_store.rs
+++ b/src/memory_backends/file_store.rs
@@ -3,9 +3,9 @@ use lsp_types::TextDocumentPositionParams;
 use ropey::Rope;
 use std::collections::HashMap;
 
-use crate::configuration::Configuration;
+use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};
 
-use super::MemoryBackend;
+use super::{MemoryBackend, Prompt};
 
 pub struct FileStore {
     configuration: Configuration,
@@ -34,7 +34,7 @@ impl MemoryBackend for FileStore {
             .to_string())
     }
 
-    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
+    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt> {
         let mut rope = self
             .file_map
             .get(position.text_document.uri.as_str())
@@ -45,13 +45,14 @@ impl MemoryBackend for FileStore {
             + position.position.character as usize;
 
         // We only want to do FIM if the user has enabled it, and the cursor is not at the end of the file
-        match self.configuration.get_fim() {
+        let code = match self.configuration.get_fim() {
            Some(fim) if rope.len_chars() != cursor_index => {
-                let max_length = self.configuration.get_maximum_context_length();
+                let max_length =
+                    characters_to_estimated_tokens(self.configuration.get_maximum_context_length());
                 let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
                 let end = rope
                     .len_chars()
-                    .min(cursor_index + (max_length - (start - cursor_index)));
+                    .min(cursor_index + (max_length - (cursor_index - start)));
                 rope.insert(end, &fim.end);
                 rope.insert(cursor_index, &fim.middle);
                 rope.insert(start, &fim.start);
@@ -64,18 +65,21 @@ impl MemoryBackend for FileStore {
                         + fim.end.chars().count(),
                 )
                 .context("Error getting rope slice")?;
-                Ok(rope_slice.to_string())
+                rope_slice.to_string()
             }
             _ => {
                 let start = cursor_index
-                    .checked_sub(self.configuration.get_maximum_context_length())
+                    .checked_sub(characters_to_estimated_tokens(
+                        self.configuration.get_maximum_context_length(),
+                    ))
                     .unwrap_or(0);
                 let rope_slice = rope
                     .get_slice(start..cursor_index)
                     .context("Error getting rope slice")?;
-                Ok(rope_slice.to_string())
+                rope_slice.to_string()
             }
-        }
+        };
+        Ok(Prompt::new("".to_string(), code))
     }
 
     fn opened_text_document(
diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs
index 64232ac..ac7840e 100644
--- a/src/memory_backends/mod.rs
+++ b/src/memory_backends/mod.rs
@@ -7,6 +7,18 @@ use crate::configuration::{Configuration, ValidMemoryBackend};
 pub mod file_store;
 
+#[derive(Debug)]
+pub struct Prompt {
+    pub context: String,
+    pub code: String,
+}
+
+impl Prompt {
+    fn new(context: String, code: String) -> Self {
+        Self { context, code }
+    }
+}
+
 pub trait MemoryBackend {
     fn init(&self) -> anyhow::Result<()> {
         Ok(())
     }
@@ -14,8 +26,7 @@ pub trait MemoryBackend {
     fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
     fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
     fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
-    // Should return an enum of either chat messages or just a prompt string
-    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
+    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt>;
     fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
 }
 
diff --git a/src/template.rs b/src/template.rs
new file mode 100644
index 0000000..412fb4f
--- /dev/null
+++ b/src/template.rs
@@ -0,0 +1,39 @@
+use crate::{
+    configuration::{Chat, ChatMessage, Configuration},
+    tokenizer::Tokenizer,
+};
+use hf_hub::api::sync::{Api, ApiRepo};
+
+// // Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
+// const CHATML_CHAT_TEMPLATE: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
+// const CHATML_BOS_TOKEN: &str = "<s>";
+// const CHATML_EOS_TOKEN: &str = "<|im_end|>";
+
+// // Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+// const MISTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
+// const MISTRAL_INSTRUCT_BOS_TOKEN: &str = "<s>";
+// const MISTRAL_INSTRUCT_EOS_TOKEN: &str = "</s>";
+
+// // Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
+// const MIXTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
+
+pub struct Template {
+    configuration: Configuration,
+}
+
+// impl Template {
+//     pub fn new(configuration: Configuration) -> Self {
+//         Self { configuration }
+//     }
+// }
+
+pub fn apply_prompt(
+    chat_messages: Vec<ChatMessage>,
+    chat: &Chat,
+    tokenizer: Option<&Tokenizer>,
+) -> anyhow::Result<String> {
+    // If we have the chat template apply it
+    // If we have the chat_format see if we have it set
+    // If we don't have the chat_format set here, try and get the chat_template from the tokenizer_config.json file
+    anyhow::bail!("Please set chat_template or chat_format. Could not find the information in the tokenizer_config.json file")
+}
diff --git a/src/tokenizer.rs b/src/tokenizer.rs
new file mode 100644
index 0000000..f656264
--- /dev/null
+++ b/src/tokenizer.rs
@@ -0,0 +1,7 @@
+pub struct Tokenizer {}
+
+impl Tokenizer {
+    pub fn maybe_from_repo(repo: ApiRepo) -> anyhow::Result<Option<Self>> {
+        unimplemented!()
+    }
+}
diff --git a/src/transformer_backends/llama_cpp/mod.rs b/src/transformer_backends/llama_cpp/mod.rs
index 037f9af..fab1c41 100644
--- a/src/transformer_backends/llama_cpp/mod.rs
+++ b/src/transformer_backends/llama_cpp/mod.rs
@@ -3,7 +3,11 @@ use hf_hub::api::sync::Api;
 use super::TransformerBackend;
 use crate::{
-    configuration::Configuration,
+    configuration::{Chat, Configuration},
+    memory_backends::Prompt,
+    template::{apply_prompt, Template},
+    tokenizer::Tokenizer,
+    utils::format_chat_messages,
     worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
@@ -15,6 +19,7 @@ use model::Model;
 pub struct LlamaCPP {
     model: Model,
     configuration: Configuration,
+    tokenizer: Option<Tokenizer>,
 }
 
 impl LlamaCPP {
@@ -27,29 +32,42 @@ impl LlamaCPP {
             .context("Model `name` is required when using GGUF models")?;
         let repo = api.model(model.repository.to_owned());
         let model_path = repo.get(&name)?;
-
+        let tokenizer: Option<Tokenizer> = Tokenizer::maybe_from_repo(repo)?;
         let model = Model::new(model_path, configuration.get_model_kwargs()?)?;
-
         Ok(Self {
             model,
             configuration,
+            tokenizer,
         })
     }
 }
 
 impl TransformerBackend for LlamaCPP {
-    fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse> {
+    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+        // We need to check that they not only set the `chat` key, but they set the `completion` sub key
+        let prompt = match self.configuration.get_chat() {
+            Some(c) => {
+                if let Some(completion_messages) = &c.completion {
+                    let chat_messages = format_chat_messages(completion_messages, prompt);
+                    apply_prompt(chat_messages, c, self.tokenizer.as_ref())?
+                } else {
+                    prompt.code.to_owned()
+                }
+            }
+            None => prompt.code.to_owned(),
+        };
         let max_new_tokens = self.configuration.get_max_new_tokens().completion;
         self.model
-            .complete(prompt, max_new_tokens)
+            .complete(&prompt, max_new_tokens)
            .map(|insert_text| DoCompletionResponse { insert_text })
     }
 
-    fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse> {
-        let max_new_tokens = self.configuration.get_max_new_tokens().generation;
-        self.model
-            .complete(prompt, max_new_tokens)
-            .map(|generated_text| DoGenerateResponse { generated_text })
+    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+        unimplemented!()
+        // let max_new_tokens = self.configuration.get_max_new_tokens().generation;
+        // self.model
+        //     .complete(prompt, max_new_tokens)
+        //     .map(|generated_text| DoGenerateResponse { generated_text })
     }
 
     fn do_generate_stream(
@@ -74,8 +92,6 @@ mod tests {
             },
             "macos": {
                 "model_gguf": {
-                    // "repository": "deepseek-coder-6.7b-base",
-                    // "name": "Q4_K_M.gguf",
                     "repository": "stabilityai/stable-code-3b",
                     "name": "stable-code-3b-Q5_K_M.gguf",
                     "max_new_tokens": {
@@ -110,7 +126,6 @@ mod tests {
                 ]
             },
             "n_ctx": 2048,
-            "n_threads": 8,
             "n_gpu_layers": 1000,
             }
         },
     });
     let configuration = Configuration::new(args).unwrap();
     let model = LlamaCPP::new(configuration).unwrap();
-    let output = model.do_completion("def fibon").unwrap();
-    println!("{}", output.insert_text);
+    // let output = model.do_completion("def fibon").unwrap();
+    // println!("{}", output.insert_text);
 }
}
diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs
index 7340d1b..1f9fd4c 100644
--- a/src/transformer_backends/mod.rs
+++ b/src/transformer_backends/mod.rs
@@ -1,5 +1,6 @@
 use crate::{
     configuration::{Configuration, ValidTransformerBackend},
+    memory_backends::Prompt,
     worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
@@ -9,8 +10,8 @@ pub mod llama_cpp;
 
 pub trait TransformerBackend {
     // Should all take an enum of chat messages or just a string for completion
-    fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse>;
-    fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse>;
+    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
+    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
     fn do_generate_stream(
         &self,
         request: &GenerateStreamRequest,
diff --git a/src/utils.rs b/src/utils.rs
index cb2ae49..1dc5b06 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -1,5 +1,7 @@
 use lsp_server::ResponseError;
 
+use crate::{configuration::ChatMessage, memory_backends::Prompt};
+
 pub trait ToResponseError {
     fn to_response_error(&self, code: i32) -> ResponseError;
 }
@@ -13,3 +15,20 @@ impl ToResponseError for anyhow::Error {
         }
     }
 }
+
+pub fn characters_to_estimated_tokens(characters: usize) -> usize {
+    characters * 4
+}
+
+pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {
+    messages
+        .iter()
+        .map(|m| ChatMessage {
+            role: m.role.to_owned(),
+            message: m
+                .message
+                .replace("{context}", &prompt.context)
+                .replace("{code}", &prompt.code),
+        })
+        .collect()
+}
diff --git a/src/worker.rs b/src/worker.rs
index 641fdd5..49cb7b9 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -36,6 +36,8 @@ impl GenerateRequest {
     }
 }
 
+// The generate stream is not yet ready but we don't want to remove it
+#[allow(dead_code)]
 #[derive(Clone)]
 pub struct GenerateStreamRequest {
     id: RequestId,
@@ -98,10 +100,10 @@ impl Worker {
             .memory_backend
             .lock()
             .get_filter_text(&request.params.text_document_position)?;
-        eprintln!("\nPROMPT**************\n{}\n******************\n", prompt);
+        eprintln!("\nPROMPT**************\n{:?}\n******************\n", prompt);
         let response = self.transformer_backend.do_completion(&prompt)?;
         eprintln!(
-            "\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{}\n&&&&&&&&&&&&&&&&&&\n",
+            "\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{:?}\n&&&&&&&&&&&&&&&&&&\n",
             response.insert_text
         );
         let completion_text_edit = TextEdit::new(
@@ -142,7 +144,7 @@ impl Worker {
             .memory_backend
             .lock()
             .build_prompt(&request.params.text_document_position)?;
-        eprintln!("\nPROMPT*************\n{}\n************\n", prompt);
+        eprintln!("\nPROMPT*************\n{:?}\n************\n", prompt);
         let response = self.transformer_backend.do_generate(&prompt)?;
         let result = GenerateResult {
             generated_text: response.generated_text,
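
Note, not part of the patch: `apply_prompt` in `src/template.rs` is still a stub that always bails. The sketch below shows one way the `chat_template` path could eventually be rendered with the `minijinja` dependency this commit adds. The helper name `render_chat_template`, the `raise_exception` shim, and the remapping of `ChatMessage::message` onto the template's `content` key are illustrative assumptions, not code from this commit.

use anyhow::Result;
use minijinja::{context, Environment, Error, ErrorKind};

use crate::configuration::ChatMessage;

// Hypothetical helper: render a Jinja-style chat template (such as the
// commented-out CHATML_CHAT_TEMPLATE above) into a single prompt string.
pub fn render_chat_template(
    template: &str,
    bos_token: &str,
    eos_token: &str,
    messages: &[ChatMessage],
) -> Result<String> {
    let mut env = Environment::new();
    // Hugging Face templates may call `raise_exception`, which is not a
    // minijinja builtin, so register a small shim that surfaces the message.
    env.add_function("raise_exception", |msg: String| -> Result<String, Error> {
        Err(Error::new(ErrorKind::InvalidOperation, msg))
    });
    env.add_template("chat", template)?;

    // The templates index `message['role']` and `message['content']`, while the
    // config struct stores the text in `message`, so remap that field here.
    let messages: Vec<_> = messages
        .iter()
        .map(|m| context! { role => m.role.clone(), content => m.message.clone() })
        .collect();

    let rendered = env.get_template("chat")?.render(context! {
        messages => messages,
        bos_token => bos_token,
        eos_token => eos_token,
        add_generation_prompt => true,
    })?;
    Ok(rendered)
}

With the ChatML constants quoted above, a call like render_chat_template(CHATML_CHAT_TEMPLATE, "<s>", "<|im_end|>", &chat_messages) is intended to produce the `<|im_start|>`/`<|im_end|>` framing that the new `do_completion` path expects back from `apply_prompt`.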