From aa7c4061cff31601f6fcd7897394c5f38a65a61d Mon Sep 17 00:00:00 2001
From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 8 Mar 2024 15:12:37 -0800
Subject: [PATCH] Added templating and some other great things

---
 Cargo.lock                                  | 21 +++++++
 Cargo.toml                                  |  3 +-
 editors/vscode/package-lock.json            |  7 +++
 editors/vscode/package.json                 |  1 +
 editors/vscode/src/index.ts                 | 66 +++++++++++++--------
 src/configuration.rs                        | 60 ++++++++++++++++++-
 src/main.rs                                 |  4 +-
 src/template.rs                             | 35 +++++++++++
 src/transformer_backends/llama_cpp/mod.rs   | 22 ++++---
 src/transformer_backends/llama_cpp/model.rs | 16 ++++-
 10 files changed, 196 insertions(+), 39 deletions(-)
 create mode 100644 src/template.rs

diff --git a/Cargo.lock b/Cargo.lock
index f623ed5..4caebd0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -712,6 +712,7 @@ dependencies = [
  "tokenizers",
  "tracing",
  "tracing-subscriber",
+ "xxhash-rust",
 ]
 
 [[package]]
@@ -770,12 +771,20 @@ version = "2.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
 
+[[package]]
+name = "memo-map"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "374c335b2df19e62d4cb323103473cbc6510980253119180de862d89184f6a83"
+
 [[package]]
 name = "minijinja"
 version = "1.0.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6fe0ff215195a22884d867b547c70a0c4815cbbcc70991f281dca604b20d10ce"
 dependencies = [
+ "memo-map",
+ "self_cell",
  "serde",
 ]
 
@@ -1307,6 +1316,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "self_cell"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
+
 [[package]]
 name = "serde"
 version = "1.0.197"
@@ -1897,6 +1912,12 @@ version = "0.52.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"
 
+[[package]]
+name = "xxhash-rust"
+version = "0.8.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "927da81e25be1e1a2901d59b81b37dd2efd1fc9c9345a55007f09bf5a2d3ee03"
+
 [[package]]
 name = "zeroize"
 version = "1.7.0"
diff --git a/Cargo.toml b/Cargo.toml
index 06fa1d7..d5cef2c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,9 +21,10 @@ once_cell = "1.19.0"
 directories = "5.0.1"
 # llama-cpp-2 = "0.1.31"
 llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
-minijinja = "1.0.12"
+minijinja = { version = "1.0.12", features = ["loader"] }
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 tracing = "0.1.40"
+xxhash-rust = { version = "0.8.5", features = ["xxh3"] }
 
 [features]
 default = []
diff --git a/editors/vscode/package-lock.json b/editors/vscode/package-lock.json
index fdee0ac..3619607 100644
--- a/editors/vscode/package-lock.json
+++ b/editors/vscode/package-lock.json
@@ -15,6 +15,7 @@
       },
       "devDependencies": {
         "@types/node": "^20.11.0",
+        "@types/uuid": "^9.0.8",
         "typescript": "^5.3.3"
       },
       "engines": {
@@ -30,6 +31,12 @@
         "undici-types": "~5.26.4"
       }
     },
+    "node_modules/@types/uuid": {
+      "version": "9.0.8",
+      "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz",
+      "integrity": "sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==",
+      "dev": true
+    },
     "node_modules/@types/vscode": {
       "version": "1.85.0",
       "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.85.0.tgz",
diff --git a/editors/vscode/package.json b/editors/vscode/package.json
index f143db2..b5ffcdf 100644
--- a/editors/vscode/package.json
+++ b/editors/vscode/package.json
@@ -38,6 +38,7 @@
   },
   "devDependencies": {
     "@types/node": "^20.11.0",
+    "@types/uuid": "^9.0.8",
     "typescript": "^5.3.3"
   },
   "dependencies": {
diff --git a/editors/vscode/src/index.ts b/editors/vscode/src/index.ts
index d035af9..4e108ee 100644
--- a/editors/vscode/src/index.ts
+++ b/editors/vscode/src/index.ts
@@ -1,18 +1,18 @@
-import * as vscode from 'vscode';
+import * as vscode from 'vscode';
 import {
   LanguageClient,
   LanguageClientOptions,
   ServerOptions,
   TransportKind
 } from 'vscode-languageclient/node';
-import { v4 as uuidv4 } from 'uuid';
+// import { v4 as uuidv4 } from 'uuid';
 
 let client: LanguageClient;
 
 export function activate(context: vscode.ExtensionContext) {
   // Configure the server options
   let serverOptions: ServerOptions = {
-    command: "lsp-ai",
+    command: "lsp-ai",
     transport: TransportKind.stdio,
   };
 
@@ -34,7 +34,7 @@ export function activate(context: vscode.ExtensionContext) {
 
   // Register generate function
   const generateCommand = 'lsp-ai.generate';
-  const generateCommandHandler = (editor) => {
+  const generateCommandHandler = (editor: vscode.TextEditor) => {
     let params = {
       textDocument: {
         uri: editor.document.uri.toString(),
@@ -42,7 +42,6 @@ export function activate(context: vscode.ExtensionContext) {
       position: editor.selection.active
     };
     client.sendRequest("textDocument/generate", params).then(result => {
-      console.log("RECEIVED RESULT", result);
       editor.edit((edit) => {
         edit.insert(editor.selection.active, result["generatedText"]);
       });
@@ -52,28 +51,43 @@ export function activate(context: vscode.ExtensionContext) {
   };
   context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateCommand, generateCommandHandler));
 
-  // Register functions
-  const generateStreamCommand = 'lsp-ai.generateStream';
-  const generateStreamCommandHandler = (editor) => {
-    let params = {
-      textDocument: {
-        uri: editor.document.uri.toString(),
-      },
-      position: editor.selection.active,
-      partialResultToken: uuidv4()
-    };
-    console.log("PARAMS: ", params);
-    client.sendRequest("textDocument/generateStream", params).then(result => {
-      console.log("RECEIVED RESULT", result);
-      editor.edit((edit) => {
-        edit.insert(editor.selection.active, result["generatedText"]);
-      });
-    }).catch(error => {
-      console.error("Error making generate request", error);
-    });
-  };
-  context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateStreamCommand, generateStreamCommandHandler));
+  // This function is not ready to go
+  // const generateStreamCommand = 'lsp-ai.generateStream';
+  // const generateStreamCommandHandler = (editor: vscode.TextEditor) => {
+  //   let params = {
+  //     textDocument: {
+  //       uri: editor.document.uri.toString(),
+  //     },
+  //     position: editor.selection.active,
+  //     partialResultToken: uuidv4()
+  //   };
+  //   console.log("PARAMS: ", params);
+  //   client.sendRequest("textDocument/generateStream", params).then(result => {
+  //     console.log("RECEIVED RESULT", result);
+  //     editor.edit((edit) => {
+  //       edit.insert(editor.selection.active, result["generatedText"]);
+  //     });
+  //   }).catch(error => {
+  //     console.error("Error making generate request", error);
+  //   });
+  // };
+  // context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateStreamCommand, generateStreamCommandHandler));
+
+  vscode.languages.registerInlineCompletionItemProvider({ pattern: '**' },
+    {
+      provideInlineCompletionItems: async (document: vscode.TextDocument, position: vscode.Position) => {
+        let params = {
+          textDocument: {
+            uri: document.uri.toString(),
+          },
+          position: position
+        };
+        const result = await client.sendRequest("textDocument/generate", params);
+        return [new vscode.InlineCompletionItem(result["generatedText"])];
+      }
+    }
+  );
 }
 
 export function deactivate(): Thenable<void> | undefined {
diff --git a/src/configuration.rs b/src/configuration.rs
index ab0e413..56360ba 100644
--- a/src/configuration.rs
+++ b/src/configuration.rs
@@ -1,5 +1,5 @@
 use anyhow::{Context, Result};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use std::collections::HashMap;
 
@@ -21,7 +21,7 @@ pub enum ValidTransformerBackend {
     PostgresML,
 }
 
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ChatMessage {
     pub role: String,
     pub content: String,
@@ -241,3 +241,59 @@ impl Configuration {
         }
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn macos_model_gguf() {
+        let args = json!({
+            "memory": {
+                "file_store": {}
+            },
+            "macos": {
+                "model_gguf": {
+                    "repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
+                    "name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
+                    "max_new_tokens": {
+                        "completion": 32,
+                        "generation": 256,
+                    },
+                    "fim": {
+                        "start": "",
+                        "middle": "",
+                        "end": ""
+                    },
+                    "chat": {
+                        "completion": [
+                            {
+                                "role": "system",
+                                "content": "You are a code completion chatbot. Use the following context to complete the next segment of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
+                            },
+                            {
+                                "role": "user",
+                                "content": "Complete the following code: \n\n{code}"
+                            }
+                        ],
+                        "generation": [
+                            {
+                                "role": "system",
+                                "content": "You are a code completion chatbot. Use the following context to complete the next segment of code. \n\n{context}",
+                            },
+                            {
+                                "role": "user",
+                                "content": "Complete the following code: \n\n{code}"
+                            }
+                        ],
+                        "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
+                    },
+                    "n_ctx": 2048,
+                    "n_gpu_layers": 35,
+                }
+            },
+        });
+        Configuration::new(args).unwrap();
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 9e32eba..9246607 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -13,6 +13,7 @@ use tracing_subscriber::{EnvFilter, FmtSubscriber};
 mod configuration;
 mod custom_requests;
 mod memory_backends;
+mod template;
 mod transformer_backends;
 mod utils;
 mod worker;
@@ -25,7 +26,6 @@ use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest};
 
 use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest};
 
-// Taken directly from: https://github.com/rust-lang/rust-analyzer
 fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
     notification.method == N::METHOD
 }
@@ -48,7 +48,7 @@ fn main() -> Result<()> {
     FmtSubscriber::builder()
         .with_writer(std::io::stderr)
         .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
-        .with_max_level(tracing::Level::TRACE)
+        // .with_max_level(tracing::Level::TRACE)
         .init();
 
     let (connection, io_threads) = Connection::stdio();
diff --git a/src/template.rs b/src/template.rs
new file mode 100644
index 0000000..3d75141
--- /dev/null
+++ b/src/template.rs
@@ -0,0 +1,35 @@
+use minijinja::{context, Environment, ErrorKind};
+use once_cell::sync::Lazy;
+use parking_lot::Mutex;
+
+use crate::configuration::ChatMessage;
+
+static MINININJA_ENVIRONMENT: Lazy<Mutex<Environment<'static>>> =
+    Lazy::new(|| Mutex::new(Environment::new()));
+
+fn template_name_from_template_string(template: &str) -> String {
+    xxhash_rust::xxh3::xxh3_64(template.as_bytes()).to_string()
+}
+
+pub fn apply_chat_template(
+    template: &str,
+    chat_messages: Vec<ChatMessage>,
+    bos_token: &str,
+    eos_token: &str,
+) -> anyhow::Result<String> {
+    let template_name = template_name_from_template_string(template);
+    let mut env = MINININJA_ENVIRONMENT.lock();
+    let template = match env.get_template(&template_name) {
+        Ok(template) => template,
+        Err(e) => match e.kind() {
+            ErrorKind::TemplateNotFound => {
+                env.add_template_owned(template_name.clone(), template.to_owned())?;
+                env.get_template(&template_name)?
+            }
+            _ => anyhow::bail!(e.to_string()),
+        },
+    };
+    Ok(template.render(
+        context!(messages => chat_messages, bos_token => bos_token, eos_token => eos_token),
+    )?)
+}
diff --git a/src/transformer_backends/llama_cpp/mod.rs b/src/transformer_backends/llama_cpp/mod.rs
index e8b03d8..86307ce 100644
--- a/src/transformer_backends/llama_cpp/mod.rs
+++ b/src/transformer_backends/llama_cpp/mod.rs
@@ -1,11 +1,12 @@
 use anyhow::Context;
-use hf_hub::api::sync::Api;
+use hf_hub::api::sync::ApiBuilder;
 use tracing::{debug, instrument};
 
 use super::TransformerBackend;
 use crate::{
     configuration::Configuration,
     memory_backends::Prompt,
+    template::apply_chat_template,
     utils::format_chat_messages,
     worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
@@ -23,7 +24,7 @@ pub struct LlamaCPP {
 impl LlamaCPP {
     #[instrument]
     pub fn new(configuration: Configuration) -> anyhow::Result<Self> {
-        let api = Api::new()?;
+        let api = ApiBuilder::new().with_progress(true).build()?;
         let model = configuration.get_model()?;
         let name = model
             .name
@@ -45,8 +46,13 @@ impl LlamaCPP {
             Some(c) => {
                 if let Some(completion_messages) = &c.completion {
                     let chat_messages = format_chat_messages(completion_messages, prompt);
-                    self.model
-                        .apply_chat_template(chat_messages, c.chat_template.to_owned())?
+                    if let Some(chat_template) = &c.chat_template {
+                        let bos_token = self.model.get_bos_token()?;
+                        let eos_token = self.model.get_eos_token()?;
+                        apply_chat_template(&chat_template, chat_messages, &bos_token, &eos_token)?
+                    } else {
+                        self.model.apply_chat_template(chat_messages, None)?
+                    }
                 } else {
                     prompt.code.to_owned()
                 }
@@ -59,8 +65,9 @@ impl LlamaCPP {
 impl TransformerBackend for LlamaCPP {
     #[instrument(skip(self))]
     fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
-        let prompt = self.get_prompt_string(prompt)?;
-        // debug!("Prompt string for LLM: {}", prompt);
+        // let prompt = self.get_prompt_string(prompt)?;
+        let prompt = &prompt.code;
+        debug!("Prompt string for LLM: {}", prompt);
         let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
         self.model
             .complete(&prompt, max_new_tokens)
@@ -69,8 +76,9 @@ impl TransformerBackend for LlamaCPP {
 
     #[instrument(skip(self))]
     fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
-        let prompt = self.get_prompt_string(prompt)?;
+        // let prompt = self.get_prompt_string(prompt)?;
         // debug!("Prompt string for LLM: {}", prompt);
+        let prompt = &prompt.code;
         let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
         self.model
             .complete(&prompt, max_new_tokens)
diff --git a/src/transformer_backends/llama_cpp/model.rs b/src/transformer_backends/llama_cpp/model.rs
index ea1237d..4c049b4 100644
--- a/src/transformer_backends/llama_cpp/model.rs
+++ b/src/transformer_backends/llama_cpp/model.rs
@@ -64,7 +64,9 @@ impl Model {
     #[instrument(skip(self))]
     pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
         // initialize the context
-        let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone()));
+        let ctx_params = LlamaContextParams::default()
+            .with_n_ctx(Some(self.n_ctx.clone()))
+            .with_n_batch(self.n_ctx.get());
 
         let mut ctx = self
             .model
@@ -157,4 +159,16 @@ impl Model {
             .model
             .apply_chat_template(template, llama_chat_messages, true)?)
     }
+
+    #[instrument(skip(self))]
+    pub fn get_eos_token(&self) -> anyhow::Result<String> {
+        let token = self.model.token_eos();
+        Ok(self.model.token_to_str(token)?)
+    }
+
+    #[instrument(skip(self))]
+    pub fn get_bos_token(&self) -> anyhow::Result<String> {
+        let token = self.model.token_bos();
+        Ok(self.model.token_to_str(token)?)
+    }
 }
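
For reference, a minimal, hypothetical usage sketch (not part of the patch) of how the new `apply_chat_template` helper in src/template.rs and the now-`Serialize`-able `ChatMessage` struct fit together. The template string, message contents, and the "<s>"/"</s>" token values below are illustrative assumptions; in the patch itself the template comes from the configuration's `chat_template` and the tokens from `Model::get_bos_token`/`get_eos_token`.

    // Hypothetical crate-internal sketch; not included in the diff above.
    use crate::configuration::ChatMessage;
    use crate::template::apply_chat_template;

    fn render_example() -> anyhow::Result<String> {
        let messages = vec![
            ChatMessage {
                role: "system".to_string(),
                content: "You are a code completion chatbot.".to_string(),
            },
            ChatMessage {
                role: "user".to_string(),
                content: "Complete the following code: fn add(".to_string(),
            },
        ];
        // The first call hashes the template text (xxh3) into a template name,
        // registers it in the shared minijinja Environment, and renders it;
        // later calls with the same template string reuse the cached entry.
        let template =
            "{{ bos_token }}{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}";
        apply_chat_template(template, messages, "<s>", "</s>")
    }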