Added templating and some other great things

Silas Marvin
2024-03-08 15:12:37 -08:00
parent d818cdca6d
commit aa7c4061cf
10 changed files with 196 additions and 39 deletions

Cargo.lock (generated file, 21 changed lines)
View File

@@ -712,6 +712,7 @@ dependencies = [
 "tokenizers",
 "tracing",
 "tracing-subscriber",
+ "xxhash-rust",
]

[[package]]
@@ -770,12 +771,20 @@ version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"

+[[package]]
+name = "memo-map"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "374c335b2df19e62d4cb323103473cbc6510980253119180de862d89184f6a83"
+
[[package]]
name = "minijinja"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fe0ff215195a22884d867b547c70a0c4815cbbcc70991f281dca604b20d10ce"
dependencies = [
+ "memo-map",
+ "self_cell",
 "serde",
]
@@ -1307,6 +1316,12 @@ dependencies = [
 "libc",
]

+[[package]]
+name = "self_cell"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
+
[[package]]
name = "serde"
version = "1.0.197"
@@ -1897,6 +1912,12 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"

+[[package]]
+name = "xxhash-rust"
+version = "0.8.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "927da81e25be1e1a2901d59b81b37dd2efd1fc9c9345a55007f09bf5a2d3ee03"
+
[[package]]
name = "zeroize"
version = "1.7.0"

View File

@@ -21,9 +21,10 @@ once_cell = "1.19.0"
directories = "5.0.1"
# llama-cpp-2 = "0.1.31"
llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
-minijinja = "1.0.12"
+minijinja = { version = "1.0.12", features = ["loader"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
+xxhash-rust = { version = "0.8.5", features = ["xxh3"] }

[features]
default = []

View File

@@ -15,6 +15,7 @@
  },
  "devDependencies": {
    "@types/node": "^20.11.0",
+   "@types/uuid": "^9.0.8",
    "typescript": "^5.3.3"
  },
  "engines": {
@@ -30,6 +31,12 @@
      "undici-types": "~5.26.4"
    }
  },
+ "node_modules/@types/uuid": {
+   "version": "9.0.8",
+   "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz",
+   "integrity": "sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==",
+   "dev": true
+ },
  "node_modules/@types/vscode": {
    "version": "1.85.0",
    "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.85.0.tgz",

View File

@@ -38,6 +38,7 @@
  },
  "devDependencies": {
    "@types/node": "^20.11.0",
+   "@types/uuid": "^9.0.8",
    "typescript": "^5.3.3"
  },
  "dependencies": {

View File

@@ -1,18 +1,18 @@
import * as vscode from 'vscode';
import {
  LanguageClient,
  LanguageClientOptions,
  ServerOptions,
  TransportKind
} from 'vscode-languageclient/node';
-import { v4 as uuidv4 } from 'uuid';
+// import { v4 as uuidv4 } from 'uuid';

let client: LanguageClient;

export function activate(context: vscode.ExtensionContext) {
  // Configure the server options
  let serverOptions: ServerOptions = {
    command: "lsp-ai",
    transport: TransportKind.stdio,
  };
@@ -34,7 +34,7 @@ export function activate(context: vscode.ExtensionContext) {
  // Register generate function
  const generateCommand = 'lsp-ai.generate';
-  const generateCommandHandler = (editor) => {
+  const generateCommandHandler = (editor: vscode.TextEditor) => {
    let params = {
      textDocument: {
        uri: editor.document.uri.toString(),
@@ -42,7 +42,6 @@ export function activate(context: vscode.ExtensionContext) {
      position: editor.selection.active
    };
    client.sendRequest("textDocument/generate", params).then(result => {
-      console.log("RECEIVED RESULT", result);
      editor.edit((edit) => {
        edit.insert(editor.selection.active, result["generatedText"]);
      });
@@ -52,28 +51,43 @@ export function activate(context: vscode.ExtensionContext) {
  };
  context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateCommand, generateCommandHandler));

  // Register functions
-  const generateStreamCommand = 'lsp-ai.generateStream';
-  const generateStreamCommandHandler = (editor) => {
-    let params = {
-      textDocument: {
-        uri: editor.document.uri.toString(),
-      },
-      position: editor.selection.active,
-      partialResultToken: uuidv4()
-    };
-    console.log("PARAMS: ", params);
-    client.sendRequest("textDocument/generateStream", params).then(result => {
-      console.log("RECEIVED RESULT", result);
-      editor.edit((edit) => {
-        edit.insert(editor.selection.active, result["generatedText"]);
-      });
-    }).catch(error => {
-      console.error("Error making generate request", error);
-    });
-  };
-  context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateStreamCommand, generateStreamCommandHandler));
+  // This function is not ready to go
+  // const generateStreamCommand = 'lsp-ai.generateStream';
+  // const generateStreamCommandHandler = (editor: vscode.TextEditor) => {
+  //   let params = {
+  //     textDocument: {
+  //       uri: editor.document.uri.toString(),
+  //     },
+  //     position: editor.selection.active,
+  //     partialResultToken: uuidv4()
+  //   };
+  //   console.log("PARAMS: ", params);
+  //   client.sendRequest("textDocument/generateStream", params).then(result => {
+  //     console.log("RECEIVED RESULT", result);
+  //     editor.edit((edit) => {
+  //       edit.insert(editor.selection.active, result["generatedText"]);
+  //     });
+  //   }).catch(error => {
+  //     console.error("Error making generate request", error);
+  //   });
+  // };
+  // context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateStreamCommand, generateStreamCommandHandler));
+
+  vscode.languages.registerInlineCompletionItemProvider({ pattern: '**' },
+    {
+      provideInlineCompletionItems: async (document: vscode.TextDocument, position: vscode.Position) => {
+        let params = {
+          textDocument: {
+            uri: document.uri.toString(),
+          },
+          position: position
+        };
+        const result = await client.sendRequest("textDocument/generate", params);
+        return [new vscode.InlineCompletionItem(result["generatedText"])];
+      }
+    }
+  );
}

export function deactivate(): Thenable<void> | undefined {

View File

@@ -1,5 +1,5 @@
use anyhow::{Context, Result};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
@@ -21,7 +21,7 @@ pub enum ValidTransformerBackend {
    PostgresML,
}

-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
@@ -241,3 +241,59 @@ impl Configuration {
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use serde_json::json;

    #[test]
    fn macos_model_gguf() {
        let args = json!({
            "memory": {
                "file_store": {}
            },
            "macos": {
                "model_gguf": {
                    "repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
                    "name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
                    "max_new_tokens": {
                        "completion": 32,
                        "generation": 256,
                    },
                    "fim": {
                        "start": "<fim_prefix>",
                        "middle": "<fim_suffix>",
                        "end": "<fim_middle>"
                    },
                    "chat": {
                        "completion": [
                            {
                                "role": "system",
                                "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
                            },
                            {
                                "role": "user",
                                "content": "Complete the following code: \n\n{code}"
                            }
                        ],
                        "generation": [
                            {
                                "role": "system",
                                "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
                            },
                            {
                                "role": "user",
                                "content": "Complete the following code: \n\n{code}"
                            }
                        ],
                        "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
                    },
                    "n_ctx": 2048,
                    "n_gpu_layers": 35,
                }
            },
        });
        Configuration::new(args).unwrap();
    }
}

View File

@@ -13,6 +13,7 @@ use tracing_subscriber::{EnvFilter, FmtSubscriber};
mod configuration;
mod custom_requests;
mod memory_backends;
+mod template;
mod transformer_backends;
mod utils;
mod worker;
@@ -25,7 +26,6 @@ use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest};
use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest};

-// Taken directly from: https://github.com/rust-lang/rust-analyzer
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
    notification.method == N::METHOD
}
@@ -48,7 +48,7 @@ fn main() -> Result<()> {
    FmtSubscriber::builder()
        .with_writer(std::io::stderr)
        .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
-        .with_max_level(tracing::Level::TRACE)
+        // .with_max_level(tracing::Level::TRACE)
        .init();

    let (connection, io_threads) = Connection::stdio();

src/template.rs (new file, 35 lines)
View File

@@ -0,0 +1,35 @@
use minijinja::{context, Environment, ErrorKind};
use once_cell::sync::Lazy;
use parking_lot::Mutex;

use crate::configuration::ChatMessage;

static MINININJA_ENVIRONMENT: Lazy<Mutex<Environment>> =
    Lazy::new(|| Mutex::new(Environment::new()));

fn template_name_from_template_string(template: &str) -> String {
    xxhash_rust::xxh3::xxh3_64(template.as_bytes()).to_string()
}

pub fn apply_chat_template(
    template: &str,
    chat_messages: Vec<ChatMessage>,
    bos_token: &str,
    eos_token: &str,
) -> anyhow::Result<String> {
    let template_name = template_name_from_template_string(template);
    let mut env = MINININJA_ENVIRONMENT.lock();
    let template = match env.get_template(&template_name) {
        Ok(template) => template,
        Err(e) => match e.kind() {
            ErrorKind::TemplateNotFound => {
                env.add_template_owned(template_name.clone(), template.to_owned())?;
                env.get_template(&template_name)?
            }
            _ => anyhow::bail!(e.to_string()),
        },
    };
    Ok(template.render(
        context!(messages => chat_messages, bos_token => bos_token, eos_token => eos_token),
    )?)
}
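
For orientation, here is a minimal sketch of how the new template module can be exercised on its own. The inline template, the message contents, and the "<s>"/"</s>" token strings are invented for illustration; only apply_chat_template and ChatMessage come from the code added in this commit, so read it as a hypothetical unit test rather than part of the change:

#[cfg(test)]
mod tests {
    use crate::{configuration::ChatMessage, template::apply_chat_template};

    #[test]
    fn renders_a_simple_chat_template() -> anyhow::Result<()> {
        // Invented Jinja-style template; real templates come from the
        // "chat_template" field of the configuration.
        let template =
            "{{ bos_token }}{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}";

        let messages = vec![
            ChatMessage {
                role: "system".to_string(),
                content: "You are a code completion chatbot.".to_string(),
            },
            ChatMessage {
                role: "user".to_string(),
                content: "Complete the following code: ...".to_string(),
            },
        ];

        // The template string is hashed with xxh3 and cached in the shared
        // minijinja Environment, so a second call with the same template
        // reuses the compiled version instead of re-parsing it.
        let rendered = apply_chat_template(template, messages, "<s>", "</s>")?;
        assert!(rendered.starts_with("<s>system: You are a code completion chatbot."));
        Ok(())
    }
}

Because templates are keyed by a hash of their own text, callers can pass the raw template string straight from the configuration without registering template names ahead of time.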

View File

@@ -1,11 +1,12 @@
use anyhow::Context;
-use hf_hub::api::sync::Api;
+use hf_hub::api::sync::ApiBuilder;
use tracing::{debug, instrument};

use super::TransformerBackend;
use crate::{
    configuration::Configuration,
    memory_backends::Prompt,
+    template::apply_chat_template,
    utils::format_chat_messages,
    worker::{
        DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
@@ -23,7 +24,7 @@ pub struct LlamaCPP {
impl LlamaCPP {
    #[instrument]
    pub fn new(configuration: Configuration) -> anyhow::Result<Self> {
-        let api = Api::new()?;
+        let api = ApiBuilder::new().with_progress(true).build()?;
        let model = configuration.get_model()?;
        let name = model
            .name
@@ -45,8 +46,13 @@ impl LlamaCPP {
            Some(c) => {
                if let Some(completion_messages) = &c.completion {
                    let chat_messages = format_chat_messages(completion_messages, prompt);
-                    self.model
-                        .apply_chat_template(chat_messages, c.chat_template.to_owned())?
+                    if let Some(chat_template) = &c.chat_template {
+                        let bos_token = self.model.get_bos_token()?;
+                        let eos_token = self.model.get_eos_token()?;
+                        apply_chat_template(&chat_template, chat_messages, &bos_token, &eos_token)?
+                    } else {
+                        self.model.apply_chat_template(chat_messages, None)?
+                    }
                } else {
                    prompt.code.to_owned()
                }
@@ -59,8 +65,9 @@ impl LlamaCPP {
impl TransformerBackend for LlamaCPP {
    #[instrument(skip(self))]
    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
-        let prompt = self.get_prompt_string(prompt)?;
-        // debug!("Prompt string for LLM: {}", prompt);
+        // let prompt = self.get_prompt_string(prompt)?;
+        let prompt = &prompt.code;
+        debug!("Prompt string for LLM: {}", prompt);
        let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
        self.model
            .complete(&prompt, max_new_tokens)
@@ -69,8 +76,9 @@ impl TransformerBackend for LlamaCPP {
    #[instrument(skip(self))]
    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
-        let prompt = self.get_prompt_string(prompt)?;
+        // let prompt = self.get_prompt_string(prompt)?;
        // debug!("Prompt string for LLM: {}", prompt);
+        let prompt = &prompt.code;
        let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
        self.model
            .complete(&prompt, max_new_tokens)

View File

@@ -64,7 +64,9 @@ impl Model {
    #[instrument(skip(self))]
    pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
        // initialize the context
-        let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone()));
+        let ctx_params = LlamaContextParams::default()
+            .with_n_ctx(Some(self.n_ctx.clone()))
+            .with_n_batch(self.n_ctx.get());
        let mut ctx = self
            .model
@@ -157,4 +159,16 @@ impl Model {
            .model
            .apply_chat_template(template, llama_chat_messages, true)?)
    }

+    #[instrument(skip(self))]
+    pub fn get_eos_token(&self) -> anyhow::Result<String> {
+        let token = self.model.token_eos();
+        Ok(self.model.token_to_str(token)?)
+    }
+
+    #[instrument(skip(self))]
+    pub fn get_bos_token(&self) -> anyhow::Result<String> {
+        let token = self.model.token_bos();
+        Ok(self.model.token_to_str(token)?)
+    }
}