Mirror of https://github.com/SilasMarvin/lsp-ai.git
Added templating and some other great things

Cargo.lock (generated): 21 changes

@@ -712,6 +712,7 @@ dependencies = [
  "tokenizers",
  "tracing",
  "tracing-subscriber",
+ "xxhash-rust",
 ]

 [[package]]

@@ -770,12 +771,20 @@ version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"

[[package]]
name = "memo-map"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374c335b2df19e62d4cb323103473cbc6510980253119180de862d89184f6a83"

[[package]]
name = "minijinja"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fe0ff215195a22884d867b547c70a0c4815cbbcc70991f281dca604b20d10ce"
dependencies = [
 "memo-map",
 "self_cell",
 "serde",
]

@@ -1307,6 +1316,12 @@ dependencies = [
 "libc",
]

[[package]]
name = "self_cell"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"

[[package]]
name = "serde"
version = "1.0.197"

@@ -1897,6 +1912,12 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"

[[package]]
name = "xxhash-rust"
version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "927da81e25be1e1a2901d59b81b37dd2efd1fc9c9345a55007f09bf5a2d3ee03"

[[package]]
name = "zeroize"
version = "1.7.0"

Cargo.toml:

@@ -21,9 +21,10 @@ once_cell = "1.19.0"
 directories = "5.0.1"
 # llama-cpp-2 = "0.1.31"
 llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
-minijinja = "1.0.12"
+minijinja = { version = "1.0.12", features = ["loader"] }
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 tracing = "0.1.40"
+xxhash-rust = { version = "0.8.5", features = ["xxh3"] }

 [features]
 default = []
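
Worth noting: minijinja's add_template_owned, which the new src/template.rs below relies on, is gated behind the crate's "loader" feature; that is what this Cargo.toml change turns on. A minimal sketch of what the feature enables (the template name and source here are illustrative):

use minijinja::{context, Environment};

fn main() -> Result<(), minijinja::Error> {
    let mut env = Environment::new();
    // Storing an owned (heap-allocated) template requires the "loader" feature.
    env.add_template_owned("hello".to_string(), "Hello {{ name }}!".to_string())?;
    let tmpl = env.get_template("hello")?;
    println!("{}", tmpl.render(context!(name => "world"))?);
    Ok(())
}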

editors/vscode/package-lock.json (generated): 7 changes

@@ -15,6 +15,7 @@
 },
 "devDependencies": {
   "@types/node": "^20.11.0",
+  "@types/uuid": "^9.0.8",
   "typescript": "^5.3.3"
 },
 "engines": {

@@ -30,6 +31,12 @@
   "undici-types": "~5.26.4"
  }
 },
+"node_modules/@types/uuid": {
+  "version": "9.0.8",
+  "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz",
+  "integrity": "sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==",
+  "dev": true
+},
 "node_modules/@types/vscode": {
  "version": "1.85.0",
  "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.85.0.tgz",

editors/vscode/package.json:

@@ -38,6 +38,7 @@
 },
 "devDependencies": {
   "@types/node": "^20.11.0",
+  "@types/uuid": "^9.0.8",
   "typescript": "^5.3.3"
 },
 "dependencies": {

editors/vscode extension source (TypeScript):

@@ -5,7 +5,7 @@ import {
   ServerOptions,
   TransportKind
 } from 'vscode-languageclient/node';
-import { v4 as uuidv4 } from 'uuid';
+// import { v4 as uuidv4 } from 'uuid';

 let client: LanguageClient;

@@ -34,7 +34,7 @@ export function activate(context: vscode.ExtensionContext) {

   // Register generate function
   const generateCommand = 'lsp-ai.generate';
-  const generateCommandHandler = (editor) => {
+  const generateCommandHandler = (editor: vscode.TextEditor) => {
     let params = {
       textDocument: {
         uri: editor.document.uri.toString(),

@@ -42,7 +42,6 @@ export function activate(context: vscode.ExtensionContext) {
       position: editor.selection.active
     };
     client.sendRequest("textDocument/generate", params).then(result => {
-      console.log("RECEIVED RESULT", result);
       editor.edit((edit) => {
         edit.insert(editor.selection.active, result["generatedText"]);
       });

@@ -52,28 +51,43 @@ export function activate(context: vscode.ExtensionContext) {
   };
   context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateCommand, generateCommandHandler));

-  // Register functions
-  const generateStreamCommand = 'lsp-ai.generateStream';
-  const generateStreamCommandHandler = (editor) => {
-    let params = {
-      textDocument: {
-        uri: editor.document.uri.toString(),
-      },
-      position: editor.selection.active,
-      partialResultToken: uuidv4()
-    };
-    console.log("PARAMS: ", params);
-    client.sendRequest("textDocument/generateStream", params).then(result => {
-      console.log("RECEIVED RESULT", result);
-      editor.edit((edit) => {
-        edit.insert(editor.selection.active, result["generatedText"]);
-      });
-    }).catch(error => {
-      console.error("Error making generate request", error);
-    });
-  };
-  context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateStreamCommand, generateStreamCommandHandler));
+  // This function is not ready to go
+  // const generateStreamCommand = 'lsp-ai.generateStream';
+  // const generateStreamCommandHandler = (editor: vscode.TextEditor) => {
+  //   let params = {
+  //     textDocument: {
+  //       uri: editor.document.uri.toString(),
+  //     },
+  //     position: editor.selection.active,
+  //     partialResultToken: uuidv4()
+  //   };
+  //   console.log("PARAMS: ", params);
+  //   client.sendRequest("textDocument/generateStream", params).then(result => {
+  //     console.log("RECEIVED RESULT", result);
+  //     editor.edit((edit) => {
+  //       edit.insert(editor.selection.active, result["generatedText"]);
+  //     });
+  //   }).catch(error => {
+  //     console.error("Error making generate request", error);
+  //   });
+  // };
+  // context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateStreamCommand, generateStreamCommandHandler));
+
+  vscode.languages.registerInlineCompletionItemProvider({ pattern: '**' },
+    {
+      provideInlineCompletionItems: async (document: vscode.TextDocument, position: vscode.Position) => {
+        let params = {
+          textDocument: {
+            uri: document.uri.toString(),
+          },
+          position: position
+        };
+        const result = await client.sendRequest("textDocument/generate", params);
+        return [new vscode.InlineCompletionItem(result["generatedText"])];
+      }
+    }
+  );
 }

 export function deactivate(): Thenable<void> | undefined {
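
The inline completion provider above drives everything through a single custom LSP request, textDocument/generate, and inserts the returned generatedText. On the server side such a request is declared through lsp_types' Request trait; the sketch below is a hypothetical version of that declaration (the crate's real one lives in its custom_requests module and may differ in naming):

use lsp_types::request::Request;
use lsp_types::TextDocumentPositionParams;
use serde::{Deserialize, Serialize};

// Illustrative result type; the field serializes as "generatedText", the key
// the extension reads out of the response.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateResult {
    pub generated_text: String,
}

pub enum Generate {}

impl Request for Generate {
    type Params = TextDocumentPositionParams;
    type Result = GenerateResult;
    const METHOD: &'static str = "textDocument/generate";
}

The params the extension sends ({ textDocument, position }) line up with lsp_types' standard TextDocumentPositionParams, which is why this sketch needs no custom params type.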

src/configuration.rs:

@@ -1,5 +1,5 @@
 use anyhow::{Context, Result};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use std::collections::HashMap;

@@ -21,7 +21,7 @@ pub enum ValidTransformerBackend {
     PostgresML,
 }

-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ChatMessage {
     pub role: String,
     pub content: String,
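
ChatMessage gains Serialize because minijinja renders context values through serde: without the derive, the messages could not be handed to the context! macro in the new src/template.rs. A self-contained sketch of that round trip (the struct here is a stand-in for the crate's configuration::ChatMessage):

use minijinja::{context, Environment};
use serde::Serialize;

// Stand-in for configuration::ChatMessage.
#[derive(Serialize)]
struct ChatMessage {
    role: String,
    content: String,
}

fn main() -> Result<(), minijinja::Error> {
    let messages = vec![ChatMessage {
        role: "user".to_string(),
        content: "hello".to_string(),
    }];
    let mut env = Environment::new();
    env.add_template("chat", "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}")?;
    println!("{}", env.get_template("chat")?.render(context!(messages => messages))?);
    Ok(())
}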

@@ -241,3 +241,59 @@ impl Configuration {
         }
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn macos_model_gguf() {
+        let args = json!({
+            "memory": {
+                "file_store": {}
+            },
+            "macos": {
+                "model_gguf": {
+                    "repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
+                    "name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
+                    "max_new_tokens": {
+                        "completion": 32,
+                        "generation": 256,
+                    },
+                    "fim": {
+                        "start": "<fim_prefix>",
+                        "middle": "<fim_suffix>",
+                        "end": "<fim_middle>"
+                    },
+                    "chat": {
+                        "completion": [
+                            {
+                                "role": "system",
+                                "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
+                            },
+                            {
+                                "role": "user",
+                                "content": "Complete the following code: \n\n{code}"
+                            }
+                        ],
+                        "generation": [
+                            {
+                                "role": "system",
+                                "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
+                            },
+                            {
+                                "role": "user",
+                                "content": "Complete the following code: \n\n{code}"
+                            }
+                        ],
+                        "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set ns.found = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
+                    },
+                    "n_ctx": 2048,
+                    "n_gpu_layers": 35,
+                }
+            },
+        });
+        Configuration::new(args).unwrap();
+    }
+}
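
An aside on the fim block in this test: the mapping ("middle": "<fim_suffix>", "end": "<fim_middle>") is not a typo. Fill-in-the-middle models receive the prefix and suffix first and then generate the middle, so the prompt is assembled as start, code before the cursor, middle, code after the cursor, end. A hypothetical assembly (lsp-ai's real FIM handling lives elsewhere in the crate):

// With the tokens from the test above this yields:
// "<fim_prefix>{code before cursor}<fim_suffix>{code after cursor}<fim_middle>"
fn build_fim_prompt(start: &str, middle: &str, end: &str, before: &str, after: &str) -> String {
    format!("{start}{before}{middle}{after}{end}")
}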

src/main.rs:

@@ -13,6 +13,7 @@ use tracing_subscriber::{EnvFilter, FmtSubscriber};
 mod configuration;
 mod custom_requests;
 mod memory_backends;
+mod template;
 mod transformer_backends;
 mod utils;
 mod worker;

@@ -25,7 +26,6 @@ use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest};

 use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest};

 // Taken directly from: https://github.com/rust-lang/rust-analyzer
 fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
     notification.method == N::METHOD
 }

@@ -48,7 +48,7 @@ fn main() -> Result<()> {
     FmtSubscriber::builder()
         .with_writer(std::io::stderr)
         .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
-        .with_max_level(tracing::Level::TRACE)
+        // .with_max_level(tracing::Level::TRACE)
         .init();

     let (connection, io_threads) = Connection::stdio();
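
With this change, verbosity is chosen per run through the LSP_AI_LOG environment variable instead of a hard-coded TRACE ceiling. A condensed sketch of the pattern (tracing-subscriber with its env-filter feature, as enabled in Cargo.toml):

use tracing_subscriber::{EnvFilter, FmtSubscriber};

fn init_logging() {
    FmtSubscriber::builder()
        // stdout carries the LSP protocol itself, so logs must go to stderr.
        .with_writer(std::io::stderr)
        // Honors directives such as LSP_AI_LOG=debug or LSP_AI_LOG=lsp_ai=trace.
        .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
        .init();
}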

src/template.rs (new file, 35 lines):

@@ -0,0 +1,35 @@
+use minijinja::{context, Environment, ErrorKind};
+use once_cell::sync::Lazy;
+use parking_lot::Mutex;
+
+use crate::configuration::ChatMessage;
+
+static MINININJA_ENVIRONMENT: Lazy<Mutex<Environment>> =
+    Lazy::new(|| Mutex::new(Environment::new()));
+
+fn template_name_from_template_string(template: &str) -> String {
+    xxhash_rust::xxh3::xxh3_64(template.as_bytes()).to_string()
+}
+
+pub fn apply_chat_template(
+    template: &str,
+    chat_messages: Vec<ChatMessage>,
+    bos_token: &str,
+    eos_token: &str,
+) -> anyhow::Result<String> {
+    let template_name = template_name_from_template_string(template);
+    let mut env = MINININJA_ENVIRONMENT.lock();
+    let template = match env.get_template(&template_name) {
+        Ok(template) => template,
+        Err(e) => match e.kind() {
+            ErrorKind::TemplateNotFound => {
+                env.add_template_owned(template_name.clone(), template.to_owned())?;
+                env.get_template(&template_name)?
+            }
+            _ => anyhow::bail!(e.to_string()),
+        },
+    };
+    Ok(template.render(
+        context!(messages => chat_messages, bos_token => bos_token, eos_token => eos_token),
+    )?)
+}
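
The environment caches each template under the xxh3 hash of its source, so repeated renders of the same template skip re-parsing; on a miss the template is parsed once via add_template_owned and then fetched back. A minimal sketch of a call site (the message contents and the <s>/</s> tokens are illustrative; real values come from the configuration and the model):

use crate::configuration::ChatMessage;
use crate::template::apply_chat_template;

fn demo() -> anyhow::Result<String> {
    let messages = vec![
        ChatMessage { role: "system".to_string(), content: "You are a code completion chatbot.".to_string() },
        ChatMessage { role: "user".to_string(), content: "Complete: fn add(".to_string() },
    ];
    // Any Jinja-style chat template works; real templates come from the
    // model configuration (see the chat_template field in the test above).
    let template = "{{ bos_token }}{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}";
    apply_chat_template(template, messages, "<s>", "</s>")
}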

src/transformer_backends (LlamaCPP backend):

@@ -1,11 +1,12 @@
 use anyhow::Context;
-use hf_hub::api::sync::Api;
+use hf_hub::api::sync::ApiBuilder;
 use tracing::{debug, instrument};

 use super::TransformerBackend;
 use crate::{
     configuration::Configuration,
     memory_backends::Prompt,
+    template::apply_chat_template,
     utils::format_chat_messages,
     worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,

@@ -23,7 +24,7 @@ pub struct LlamaCPP {
 impl LlamaCPP {
     #[instrument]
     pub fn new(configuration: Configuration) -> anyhow::Result<Self> {
-        let api = Api::new()?;
+        let api = ApiBuilder::new().with_progress(true).build()?;
         let model = configuration.get_model()?;
         let name = model
             .name

@@ -45,8 +46,13 @@ impl LlamaCPP {
         Some(c) => {
             if let Some(completion_messages) = &c.completion {
                 let chat_messages = format_chat_messages(completion_messages, prompt);
-                self.model
-                    .apply_chat_template(chat_messages, c.chat_template.to_owned())?
+                if let Some(chat_template) = &c.chat_template {
+                    let bos_token = self.model.get_bos_token()?;
+                    let eos_token = self.model.get_eos_token()?;
+                    apply_chat_template(&chat_template, chat_messages, &bos_token, &eos_token)?
+                } else {
+                    self.model.apply_chat_template(chat_messages, None)?
+                }
             } else {
                 prompt.code.to_owned()
             }

@@ -59,8 +65,9 @@ impl LlamaCPP {
 impl TransformerBackend for LlamaCPP {
     #[instrument(skip(self))]
     fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
-        let prompt = self.get_prompt_string(prompt)?;
-        // debug!("Prompt string for LLM: {}", prompt);
+        // let prompt = self.get_prompt_string(prompt)?;
+        let prompt = &prompt.code;
+        debug!("Prompt string for LLM: {}", prompt);
         let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
         self.model
             .complete(&prompt, max_new_tokens)

@@ -69,8 +76,9 @@ impl TransformerBackend for LlamaCPP {

     #[instrument(skip(self))]
     fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
-        let prompt = self.get_prompt_string(prompt)?;
-        // debug!("Prompt string for LLM: {}", prompt);
+        // let prompt = self.get_prompt_string(prompt)?;
+        // debug!("Prompt string for LLM: {}", prompt);
+        let prompt = &prompt.code;
         let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
         self.model
             .complete(&prompt, max_new_tokens)

llama.cpp model wrapper (impl Model):

@@ -64,7 +64,9 @@ impl Model {
     #[instrument(skip(self))]
     pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
         // initialize the context
-        let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone()));
+        let ctx_params = LlamaContextParams::default()
+            .with_n_ctx(Some(self.n_ctx.clone()))
+            .with_n_batch(self.n_ctx.get());

         let mut ctx = self
             .model
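
A note on the hunk above: llama.cpp decodes the prompt in chunks of at most n_batch tokens, and batches larger than n_batch are rejected. Matching n_batch to the full context size means a prompt as large as the context window can be decoded in one pass. Condensed sketch (llama-cpp-2 API as used in the diff; the crate is pinned to a local path in this commit, so signatures may differ from the published version):

use std::num::NonZeroU32;
use llama_cpp_2::context::params::LlamaContextParams;

fn context_params(n_ctx: NonZeroU32) -> LlamaContextParams {
    LlamaContextParams::default()
        .with_n_ctx(Some(n_ctx))      // context window size
        .with_n_batch(n_ctx.get())    // allow the whole prompt in one batch
}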

@@ -157,4 +159,16 @@ impl Model {
             .model
             .apply_chat_template(template, llama_chat_messages, true)?)
     }
+
+    #[instrument(skip(self))]
+    pub fn get_eos_token(&self) -> anyhow::Result<String> {
+        let token = self.model.token_eos();
+        Ok(self.model.token_to_str(token)?)
+    }
+
+    #[instrument(skip(self))]
+    pub fn get_bos_token(&self) -> anyhow::Result<String> {
+        let token = self.model.token_bos();
+        Ok(self.model.token_to_str(token)?)
+    }
 }
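
The two new getters exist so the backend can feed model-specific special tokens into the minijinja template, tying this file to src/template.rs. A condensed sketch of the flow (names mirror the diffs above; the token values in the comments are merely typical examples):

use crate::{configuration::ChatMessage, template::apply_chat_template};

// `model` is the llama.cpp wrapper from this file; `chat_template` comes from
// the user's configuration.
fn render_prompt(
    model: &Model,
    chat_template: &str,
    chat_messages: Vec<ChatMessage>,
) -> anyhow::Result<String> {
    let bos_token = model.get_bos_token()?; // e.g. "<s>" on many llama-family models
    let eos_token = model.get_eos_token()?; // e.g. "</s>"
    apply_chat_template(chat_template, chat_messages, &bos_token, &eos_token)
}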