Start working on the chat feature

Silas Marvin
2024-02-26 20:03:29 -08:00
parent 418ccb81ff
commit 28b7b1b74e
12 changed files with 182 additions and 53 deletions

Cargo.lock (generated)

@@ -141,6 +141,9 @@ name = "cc"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730"
dependencies = [
"libc",
]
[[package]]
name = "cexpr"
@@ -619,8 +622,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "llama-cpp-2"
version = "0.1.25"
source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-8-metal-on-mac#8c61f584e7aa200581b711147e685821190aa025"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8a5342c3eb45011e7e3646e22c5b8fcd3f25e049f0eb9618048e40b0027a59c"
dependencies = [
"llama-cpp-sys-2",
"thiserror",
@@ -629,8 +633,9 @@ dependencies = [
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.25"
source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-8-metal-on-mac#8c61f584e7aa200581b711147e685821190aa025"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1813a55afed6298991bcaaee040b49a83b473b3571ce37b4bbaa4b294ebcc36"
dependencies = [
"bindgen",
"cc",
@@ -662,6 +667,7 @@ dependencies = [
"llama-cpp-2",
"lsp-server",
"lsp-types",
"minijinja",
"once_cell",
"parking_lot",
"rand",
@@ -674,6 +680,8 @@ dependencies = [
[[package]]
name = "lsp-server"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "248f65b78f6db5d8e1b1604b4098a28b43d21a8eb1deeca22b1c421b276c7095"
dependencies = [
"crossbeam-channel",
"log",
@@ -716,6 +724,15 @@ version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
[[package]]
name = "minijinja"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fe0ff215195a22884d867b547c70a0c4815cbbcc70991f281dca604b20d10ce"
dependencies = [
"serde",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"


@@ -7,8 +7,8 @@ edition = "2021"
[dependencies]
anyhow = "1.0.75"
# lsp-server = "0.7.4"
lsp-server = { path = "../rust-analyzer/lib/lsp-server" }
lsp-server = "0.7.4"
# lsp-server = { path = "../rust-analyzer/lib/lsp-server" }
lsp-types = "0.94.1"
ropey = "1.6.1"
serde = "1.0.190"
@@ -19,8 +19,9 @@ tokenizers = "0.14.1"
parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
# llama-cpp-2 = "0.1.27"
llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
llama-cpp-2 = "0.1.31"
# llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
minijinja = "1.0.12"
[features]
default = []


@@ -3,6 +3,8 @@ use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use crate::memory_backends::Prompt;
#[cfg(target_os = "macos")]
const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
@@ -21,6 +23,20 @@ pub enum ValidTransformerBackend {
PostgresML,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub message: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Chat {
pub completion: Option<Vec<ChatMessage>>,
pub generation: Option<Vec<ChatMessage>>,
pub chat_template: Option<String>,
pub chat_format: Option<String>,
}
#[derive(Clone, Deserialize)]
pub struct FIM {
pub start: String,
@@ -56,18 +72,6 @@ impl Default for ValidMemoryConfiguration {
}
}
#[derive(Clone, Deserialize)]
struct ChatMessages {
role: String,
message: String,
}
#[derive(Clone, Deserialize)]
struct Chat {
completion: Option<Vec<ChatMessages>>,
generation: Option<Vec<ChatMessages>>,
}
#[derive(Clone, Deserialize)]
pub struct Model {
pub repository: String,
@@ -230,6 +234,14 @@ impl Configuration {
panic!("We currently only support gguf models using llama cpp")
}
}
pub fn get_chat(&self) -> Option<&Chat> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf.chat.as_ref()
} else {
panic!("We currently only support gguf models using llama cpp")
}
}
}
#[cfg(test)]
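The new Chat and ChatMessage structs are plain serde targets, and get_chat() above reads the optional chat block off the gguf model configuration. A minimal sketch of deserializing such a block (the JSON shape is illustrative; only the field names come from this diff, and the {context}/{code} placeholders are substituted later by format_chat_messages in utils.rs):

#[test]
fn chat_block_deserializes() {
    // Illustrative only: field names are taken from Chat / ChatMessage above.
    let chat: Chat = serde_json::from_value(serde_json::json!({
        "completion": [
            { "role": "user", "message": "Context:\n{context}\n\nComplete:\n{code}" }
        ],
        "chat_format": "chatml"
    }))
    .unwrap();
    assert!(chat.completion.is_some());
    assert!(chat.generation.is_none());
    assert!(chat.chat_template.is_none());
}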


@@ -11,6 +11,8 @@ use std::{sync::Arc, thread};
mod configuration;
mod custom_requests;
mod memory_backends;
mod template;
mod tokenizer;
mod transformer_backends;
mod utils;
mod worker;
@@ -80,7 +82,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let thread_memory_backend = memory_backend.clone();
let thread_last_worker_request = last_worker_request.clone();
let thread_connection = connection.clone();
// TODO: Pass some backend into here
thread::spawn(move || {
Worker::new(
transformer_backend,


@@ -3,9 +3,9 @@ use lsp_types::TextDocumentPositionParams;
use ropey::Rope;
use std::collections::HashMap;
use crate::configuration::Configuration;
use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};
use super::MemoryBackend;
use super::{MemoryBackend, Prompt};
pub struct FileStore {
configuration: Configuration,
@@ -34,7 +34,7 @@ impl MemoryBackend for FileStore {
.to_string())
}
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt> {
let mut rope = self
.file_map
.get(position.text_document.uri.as_str())
@@ -45,13 +45,14 @@ impl MemoryBackend for FileStore {
+ position.position.character as usize;
// We only want to do FIM if the user has enabled it, and the cursor is not at the end of the file
match self.configuration.get_fim() {
let code = match self.configuration.get_fim() {
Some(fim) if rope.len_chars() != cursor_index => {
let max_length = self.configuration.get_maximum_context_length();
let max_length =
characters_to_estimated_tokens(self.configuration.get_maximum_context_length());
let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (start - cursor_index)));
.min(cursor_index + (max_length - (cursor_index - start)));
rope.insert(end, &fim.end);
rope.insert(cursor_index, &fim.middle);
rope.insert(start, &fim.start);
@@ -64,18 +65,21 @@ impl MemoryBackend for FileStore {
+ fim.end.chars().count(),
)
.context("Error getting rope slice")?;
Ok(rope_slice.to_string())
rope_slice.to_string()
}
_ => {
let start = cursor_index
.checked_sub(self.configuration.get_maximum_context_length())
.checked_sub(characters_to_estimated_tokens(
self.configuration.get_maximum_context_length(),
))
.unwrap_or(0);
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
Ok(rope_slice.to_string())
}
rope_slice.to_string()
}
};
Ok(Prompt::new("".to_string(), code))
}
fn opened_text_document(
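The windowing arithmetic above is easy to get backwards (the old expression start - cursor_index underflows, since start can never exceed cursor_index). A standalone sketch of the same calculation, with the character budget already estimated from the configured token limit; the function name window is made up for illustration:

fn window(cursor_index: usize, len_chars: usize, max_length: usize) -> (usize, usize) {
    // Spend up to half the budget before the cursor, clamped to the start of the file.
    let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
    // Whatever was not used before the cursor is spent after it, clamped to the end.
    let end = len_chars.min(cursor_index + (max_length - (cursor_index - start)));
    (start, end)
}
// window(500, 1000, 100) == (450, 550)  -- 50 characters either side of the cursor
// window(10, 1000, 100)  == (0, 100)    -- near the start, the unused left budget shifts right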


@@ -7,6 +7,18 @@ use crate::configuration::{Configuration, ValidMemoryBackend};
pub mod file_store;
#[derive(Debug)]
pub struct Prompt {
pub context: String,
pub code: String,
}
impl Prompt {
fn new(context: String, code: String) -> Self {
Self { context, code }
}
}
pub trait MemoryBackend {
fn init(&self) -> anyhow::Result<()> {
Ok(())
@@ -14,8 +26,7 @@ pub trait MemoryBackend {
fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
// Should return an enum of either chat messages or just a prompt string
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt>;
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
}
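Prompt splits what a memory backend hands back into two pieces, presumably so a backend can supply retrieved context separately from the code around the cursor; FileStore above currently passes an empty context. For reference, a trivial construction:

// FileStore in this commit always leaves context empty; the code field is the
// (possibly FIM-annotated) window around the cursor.
let prompt = Prompt::new("".to_string(), "def fibonacci(".to_string());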

src/template.rs (new file)

@@ -0,0 +1,39 @@
use crate::{
configuration::{Chat, ChatMessage, Configuration},
tokenizer::Tokenizer,
};
use hf_hub::api::sync::{Api, ApiRepo};
// // Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
// const CHATML_CHAT_TEMPLATE: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
// const CHATML_BOS_TOKEN: &str = "<s>";
// const CHATML_EOS_TOKEN: &str = "<|im_end|>";
// // Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
// const MISTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
// const MISTRAL_INSTRUCT_BOS_TOKEN: &str = "<s>";
// const MISTRAL_INSTRUCT_EOS_TOKEN: &str = "</s>";
// // Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
// const MIXTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
pub struct Template {
configuration: Configuration,
}
// impl Template {
// pub fn new(configuration: Configuration) -> Self {
// Self { configuration }
// }
// }
pub fn apply_prompt(
chat_messages: Vec<ChatMessage>,
chat: &Chat,
tokenizer: Option<&Tokenizer>,
) -> anyhow::Result<String> {
// If we have the chat template apply it
// If we have the chat_format see if we have it set
// If we don't have the chat_format set here, try and get the chat_template from the tokenizer_config.json file
anyhow::bail!("Please set chat_template or chat_format. Could not find the information in the tokenizer_config.json file")
}
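apply_prompt is still a stub in this commit, with the commented-out constants hinting at the direction. As a rough sketch of how the newly added minijinja dependency could render a ChatML-style template (the template string and message shape here are illustrative assumptions, not what the commit implements):

use minijinja::{context, Environment};

// Illustrative ChatML-style template; a real chat_template would come from the
// configuration or from the model's tokenizer_config.json.
const EXAMPLE_TEMPLATE: &str =
    "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}<|im_start|>assistant\n";

fn render_chatml_example() -> anyhow::Result<String> {
    let mut env = Environment::new();
    env.add_template("chat", EXAMPLE_TEMPLATE)?;
    let rendered = env.get_template("chat")?.render(context! {
        messages => serde_json::json!([
            { "role": "user", "content": "Write a fibonacci function" }
        ])
    })?;
    Ok(rendered)
}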

src/tokenizer.rs (new file)

@@ -0,0 +1,7 @@
pub struct Tokenizer {}
impl Tokenizer {
pub fn maybe_from_repo(repo: ApiRepo) -> anyhow::Result<Option<Self>> {
unimplemented!()
}
}
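maybe_from_repo is unimplemented here. One plausible shape, sketched under the assumption that the tokenizer is loaded from a tokenizer.json in the same Hugging Face repo via the existing hf-hub and tokenizers dependencies (a guess at intent, not the commit's code):

use hf_hub::api::sync::ApiRepo;

pub struct Tokenizer {
    inner: tokenizers::Tokenizer,
}

impl Tokenizer {
    // Sketch: fetch tokenizer.json if the repo ships one, otherwise return None.
    pub fn maybe_from_repo(repo: ApiRepo) -> anyhow::Result<Option<Self>> {
        let path = match repo.get("tokenizer.json") {
            Ok(path) => path,
            Err(_) => return Ok(None), // no tokenizer.json in this repo
        };
        let inner = tokenizers::Tokenizer::from_file(path)
            .map_err(|e| anyhow::anyhow!("error loading tokenizer: {e}"))?;
        Ok(Some(Self { inner }))
    }
}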


@@ -3,7 +3,11 @@ use hf_hub::api::sync::Api;
use super::TransformerBackend;
use crate::{
configuration::Configuration,
configuration::{Chat, Configuration},
memory_backends::Prompt,
template::{apply_prompt, Template},
tokenizer::Tokenizer,
utils::format_chat_messages,
worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
},
@@ -15,6 +19,7 @@ use model::Model;
pub struct LlamaCPP {
model: Model,
configuration: Configuration,
tokenizer: Option<Tokenizer>,
}
impl LlamaCPP {
@@ -27,29 +32,42 @@ impl LlamaCPP {
.context("Model `name` is required when using GGUF models")?;
let repo = api.model(model.repository.to_owned());
let model_path = repo.get(&name)?;
let tokenizer: Option<Tokenizer> = Tokenizer::maybe_from_repo(repo)?;
let model = Model::new(model_path, configuration.get_model_kwargs()?)?;
Ok(Self {
model,
configuration,
tokenizer,
})
}
}
impl TransformerBackend for LlamaCPP {
fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse> {
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
// We need to check that they not only set the `chat` key, but they set the `completion` sub key
let prompt = match self.configuration.get_chat() {
Some(c) => {
if let Some(completion_messages) = &c.completion {
let chat_messages = format_chat_messages(completion_messages, prompt);
apply_prompt(chat_messages, c, self.tokenizer.as_ref())?
} else {
prompt.code.to_owned()
}
}
None => prompt.code.to_owned(),
};
let max_new_tokens = self.configuration.get_max_new_tokens().completion;
self.model
.complete(prompt, max_new_tokens)
.complete(&prompt, max_new_tokens)
.map(|insert_text| DoCompletionResponse { insert_text })
}
fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse> {
let max_new_tokens = self.configuration.get_max_new_tokens().generation;
self.model
.complete(prompt, max_new_tokens)
.map(|generated_text| DoGenerateResponse { generated_text })
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
unimplemented!()
// let max_new_tokens = self.configuration.get_max_new_tokens().generation;
// self.model
// .complete(prompt, max_new_tokens)
// .map(|generated_text| DoGenerateResponse { generated_text })
}
fn do_generate_stream(
@@ -74,8 +92,6 @@ mod tests {
},
"macos": {
"model_gguf": {
// "repository": "deepseek-coder-6.7b-base",
// "name": "Q4_K_M.gguf",
"repository": "stabilityai/stable-code-3b",
"name": "stable-code-3b-Q5_K_M.gguf",
"max_new_tokens": {
@@ -110,7 +126,6 @@ mod tests {
]
},
"n_ctx": 2048,
"n_threads": 8,
"n_gpu_layers": 1000,
}
},
@@ -118,7 +133,7 @@ mod tests {
});
let configuration = Configuration::new(args).unwrap();
let model = LlamaCPP::new(configuration).unwrap();
let output = model.do_completion("def fibon").unwrap();
println!("{}", output.insert_text);
// let output = model.do_completion("def fibon").unwrap();
// println!("{}", output.insert_text);
}
}


@@ -1,5 +1,6 @@
use crate::{
configuration::{Configuration, ValidTransformerBackend},
memory_backends::Prompt,
worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
},
@@ -9,8 +10,8 @@ pub mod llama_cpp;
pub trait TransformerBackend {
// Should all take an enum of chat messages or just a string for completion
fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse>;
fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse>;
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
fn do_generate_stream(
&self,
request: &GenerateStreamRequest,


@@ -1,5 +1,7 @@
use lsp_server::ResponseError;
use crate::{configuration::ChatMessage, memory_backends::Prompt};
pub trait ToResponseError {
fn to_response_error(&self, code: i32) -> ResponseError;
}
@@ -13,3 +15,20 @@ impl ToResponseError for anyhow::Error {
}
}
}
pub fn characters_to_estimated_tokens(characters: usize) -> usize {
characters * 4
}
pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {
messages
.iter()
.map(|m| ChatMessage {
role: m.role.to_owned(),
message: m
.message
.replace("{context}", &prompt.context)
.replace("{code}", &prompt.code),
})
.collect()
}
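format_chat_messages only performs string substitution of the two placeholders into a copy of the configured messages; for example (the message text is made up):

// {code} is replaced with the prompt's code window, {context} with its context string.
let prompt = Prompt::new("".to_string(), "def fib(n):".to_string());
let messages = vec![ChatMessage {
    role: "user".to_string(),
    message: "Complete the following code:\n{code}".to_string(),
}];
let formatted = format_chat_messages(&messages, &prompt);
// formatted[0].message == "Complete the following code:\ndef fib(n):"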


@@ -36,6 +36,8 @@ impl GenerateRequest {
}
}
// The generate stream is not yet ready but we don't want to remove it
#[allow(dead_code)]
#[derive(Clone)]
pub struct GenerateStreamRequest {
id: RequestId,
@@ -98,10 +100,10 @@ impl Worker {
.memory_backend
.lock()
.get_filter_text(&request.params.text_document_position)?;
eprintln!("\nPROMPT**************\n{}\n******************\n", prompt);
eprintln!("\nPROMPT**************\n{:?}\n******************\n", prompt);
let response = self.transformer_backend.do_completion(&prompt)?;
eprintln!(
"\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{}\n&&&&&&&&&&&&&&&&&&\n",
"\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{:?}\n&&&&&&&&&&&&&&&&&&\n",
response.insert_text
);
let completion_text_edit = TextEdit::new(
@@ -142,7 +144,7 @@ impl Worker {
.memory_backend
.lock()
.build_prompt(&request.params.text_document_position)?;
eprintln!("\nPROMPT*************\n{}\n************\n", prompt);
eprintln!("\nPROMPT*************\n{:?}\n************\n", prompt);
let response = self.transformer_backend.do_generate(&prompt)?;
let result = GenerateResult {
generated_text: response.generated_text,