Mirror of https://github.com/SilasMarvin/lsp-ai.git (synced 2025-12-20 16:04:21 +01:00)
Start working on the chat feature
Cargo.lock (generated): 25 lines changed

@@ -141,6 +141,9 @@ name = "cc"
 version = "1.0.86"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730"
+dependencies = [
+ "libc",
+]

 [[package]]
 name = "cexpr"
@@ -619,8 +622,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"

 [[package]]
 name = "llama-cpp-2"
-version = "0.1.25"
-source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-8-metal-on-mac#8c61f584e7aa200581b711147e685821190aa025"
+version = "0.1.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f8a5342c3eb45011e7e3646e22c5b8fcd3f25e049f0eb9618048e40b0027a59c"
 dependencies = [
  "llama-cpp-sys-2",
  "thiserror",
@@ -629,8 +633,9 @@ dependencies = [

 [[package]]
 name = "llama-cpp-sys-2"
-version = "0.1.25"
-source = "git+https://github.com/SilasMarvin/llama-cpp-rs?branch=silas-8-metal-on-mac#8c61f584e7aa200581b711147e685821190aa025"
+version = "0.1.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e1813a55afed6298991bcaaee040b49a83b473b3571ce37b4bbaa4b294ebcc36"
 dependencies = [
  "bindgen",
  "cc",
@@ -662,6 +667,7 @@ dependencies = [
  "llama-cpp-2",
  "lsp-server",
  "lsp-types",
+ "minijinja",
  "once_cell",
  "parking_lot",
  "rand",
@@ -674,6 +680,8 @@ dependencies = [
 [[package]]
 name = "lsp-server"
 version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "248f65b78f6db5d8e1b1604b4098a28b43d21a8eb1deeca22b1c421b276c7095"
 dependencies = [
  "crossbeam-channel",
  "log",
@@ -716,6 +724,15 @@ version = "2.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"

+[[package]]
+name = "minijinja"
+version = "1.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6fe0ff215195a22884d867b547c70a0c4815cbbcc70991f281dca604b20d10ce"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "minimal-lexical"
 version = "0.2.1"
Cargo.toml:

@@ -7,8 +7,8 @@ edition = "2021"

 [dependencies]
 anyhow = "1.0.75"
-# lsp-server = "0.7.4"
-lsp-server = { path = "../rust-analyzer/lib/lsp-server" }
+lsp-server = "0.7.4"
+# lsp-server = { path = "../rust-analyzer/lib/lsp-server" }
 lsp-types = "0.94.1"
 ropey = "1.6.1"
 serde = "1.0.190"
@@ -19,8 +19,9 @@ tokenizers = "0.14.1"
 parking_lot = "0.12.1"
 once_cell = "1.19.0"
 directories = "5.0.1"
-# llama-cpp-2 = "0.1.27"
-llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
+llama-cpp-2 = "0.1.31"
+# llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
+minijinja = "1.0.12"

 [features]
 default = []
src/configuration.rs:

@@ -3,6 +3,8 @@ use serde::Deserialize;
 use serde_json::{json, Value};
 use std::collections::HashMap;

+use crate::memory_backends::Prompt;
+
 #[cfg(target_os = "macos")]
 const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;

@@ -21,6 +23,20 @@ pub enum ValidTransformerBackend {
     PostgresML,
 }

+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatMessage {
+    pub role: String,
+    pub message: String,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct Chat {
+    pub completion: Option<Vec<ChatMessage>>,
+    pub generation: Option<Vec<ChatMessage>>,
+    pub chat_template: Option<String>,
+    pub chat_format: Option<String>,
+}
+
 #[derive(Clone, Deserialize)]
 pub struct FIM {
     pub start: String,
@@ -56,18 +72,6 @@ impl Default for ValidMemoryConfiguration {
     }
 }

-#[derive(Clone, Deserialize)]
-struct ChatMessages {
-    role: String,
-    message: String,
-}
-
-#[derive(Clone, Deserialize)]
-struct Chat {
-    completion: Option<Vec<ChatMessages>>,
-    generation: Option<Vec<ChatMessages>>,
-}
-
 #[derive(Clone, Deserialize)]
 pub struct Model {
     pub repository: String,
@@ -230,6 +234,14 @@ impl Configuration {
             panic!("We currently only support gguf models using llama cpp")
         }
     }
+
+    pub fn get_chat(&self) -> Option<&Chat> {
+        if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
+            model_gguf.chat.as_ref()
+        } else {
+            panic!("We currently only support gguf models using llama cpp")
+        }
+    }
 }

 #[cfg(test)]
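The new `Chat` and `ChatMessage` structs above are plain serde targets, so a chat section in the JSON configuration the editor passes in deserializes straight into them. A minimal, self-contained sketch of that mapping; the field values and the standalone `main` are illustrative, not part of the commit, and it only assumes the `serde`, `serde_json`, and `anyhow` crates the project already depends on:

    use serde::Deserialize;

    #[derive(Debug, Clone, Deserialize)]
    pub struct ChatMessage {
        pub role: String,
        pub message: String,
    }

    #[derive(Debug, Clone, Deserialize)]
    pub struct Chat {
        pub completion: Option<Vec<ChatMessage>>,
        pub generation: Option<Vec<ChatMessage>>,
        pub chat_template: Option<String>,
        pub chat_format: Option<String>,
    }

    fn main() -> anyhow::Result<()> {
        // Illustrative chat section; the {context}/{code} placeholders are the
        // ones format_chat_messages in src/utils.rs substitutes later.
        let raw = serde_json::json!({
            "completion": [
                { "role": "system", "message": "You complete code given this context:\n{context}" },
                { "role": "user", "message": "{code}" }
            ],
            "chat_format": "chatml"
        });
        let chat: Chat = serde_json::from_value(raw)?;
        assert!(chat.completion.is_some() && chat.generation.is_none());
        println!("{:?}", chat);
        Ok(())
    }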
src/main.rs:

@@ -11,6 +11,8 @@ use std::{sync::Arc, thread};
 mod configuration;
 mod custom_requests;
 mod memory_backends;
+mod template;
+mod tokenizer;
 mod transformer_backends;
 mod utils;
 mod worker;
@@ -80,7 +82,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     let thread_memory_backend = memory_backend.clone();
     let thread_last_worker_request = last_worker_request.clone();
     let thread_connection = connection.clone();
-    // TODO: Pass some backend into here
     thread::spawn(move || {
         Worker::new(
             transformer_backend,
src/memory_backends/file_store.rs:

@@ -3,9 +3,9 @@ use lsp_types::TextDocumentPositionParams;
 use ropey::Rope;
 use std::collections::HashMap;

-use crate::configuration::Configuration;
+use crate::{configuration::Configuration, utils::characters_to_estimated_tokens};

-use super::MemoryBackend;
+use super::{MemoryBackend, Prompt};

 pub struct FileStore {
     configuration: Configuration,
@@ -34,7 +34,7 @@ impl MemoryBackend for FileStore {
             .to_string())
     }

-    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
+    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt> {
         let mut rope = self
             .file_map
             .get(position.text_document.uri.as_str())
@@ -45,13 +45,14 @@ impl MemoryBackend for FileStore {
             + position.position.character as usize;

         // We only want to do FIM if the user has enabled it, and the cursor is not at the end of the file
-        match self.configuration.get_fim() {
+        let code = match self.configuration.get_fim() {
             Some(fim) if rope.len_chars() != cursor_index => {
-                let max_length = self.configuration.get_maximum_context_length();
+                let max_length =
+                    characters_to_estimated_tokens(self.configuration.get_maximum_context_length());
                 let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
                 let end = rope
                     .len_chars()
-                    .min(cursor_index + (max_length - (start - cursor_index)));
+                    .min(cursor_index + (max_length - (cursor_index - start)));
                 rope.insert(end, &fim.end);
                 rope.insert(cursor_index, &fim.middle);
                 rope.insert(start, &fim.start);
@@ -64,18 +65,21 @@ impl MemoryBackend for FileStore {
                         + fim.end.chars().count(),
                 )
                 .context("Error getting rope slice")?;
-                Ok(rope_slice.to_string())
+                rope_slice.to_string()
             }
             _ => {
                 let start = cursor_index
-                    .checked_sub(self.configuration.get_maximum_context_length())
+                    .checked_sub(characters_to_estimated_tokens(
+                        self.configuration.get_maximum_context_length(),
+                    ))
                     .unwrap_or(0);
                 let rope_slice = rope
                     .get_slice(start..cursor_index)
                     .context("Error getting rope slice")?;
-                Ok(rope_slice.to_string())
+                rope_slice.to_string()
            }
-        }
+        };
+        Ok(Prompt::new("".to_string(), code))
     }

     fn opened_text_document(
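Two details in the file_store changes are easy to miss. First, the configured maximum context length is now passed through `characters_to_estimated_tokens` before being used to slice the rope. Second, the right edge of the FIM window was previously computed with `start - cursor_index`, which underflows a `usize` (panicking in debug builds) because `start` is at or before the cursor; the fix flips it to `cursor_index - start`, the number of characters already spent on the left of the cursor. A worked example with illustrative numbers, not taken from the commit:

    fn main() {
        // Assume a 2000-character budget around a cursor at index 1500
        // in a 10_000-character document.
        let max_length: usize = 2000;
        let cursor_index: usize = 1500;
        let len_chars: usize = 10_000;

        // Half the budget goes to the left of the cursor, clamped at 0.
        let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0); // 500

        // Old code: max_length - (start - cursor_index), i.e. 500 - 1500, underflows.
        // New code: whatever is left of the budget extends to the right.
        let used_left = cursor_index - start; // 1000
        let end = len_chars.min(cursor_index + (max_length - used_left)); // 2500

        assert_eq!((start, end), (500, 2500));
    }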
src/memory_backends/mod.rs:

@@ -7,6 +7,18 @@ use crate::configuration::{Configuration, ValidMemoryBackend};

 pub mod file_store;

+#[derive(Debug)]
+pub struct Prompt {
+    pub context: String,
+    pub code: String,
+}
+
+impl Prompt {
+    fn new(context: String, code: String) -> Self {
+        Self { context, code }
+    }
+}
+
 pub trait MemoryBackend {
     fn init(&self) -> anyhow::Result<()> {
         Ok(())
@@ -14,8 +26,7 @@ pub trait MemoryBackend {
     fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
     fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
     fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
-    // Should return an enum of either chat messages or just a prompt string
-    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
+    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<Prompt>;
     fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
 }

src/template.rs (new file): 39 lines

@@ -0,0 +1,39 @@
+use crate::{
+    configuration::{Chat, ChatMessage, Configuration},
+    tokenizer::Tokenizer,
+};
+use hf_hub::api::sync::{Api, ApiRepo};
+
+// // Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
+// const CHATML_CHAT_TEMPLATE: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
+// const CHATML_BOS_TOKEN: &str = "<s>";
+// const CHATML_EOS_TOKEN: &str = "<|im_end|>";
+
+// // Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+// const MISTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
+// const MISTRAL_INSTRUCT_BOS_TOKEN: &str = "<s>";
+// const MISTRAL_INSTRUCT_EOS_TOKEN: &str = "</s>";
+
+// // Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
+// const MIXTRAL_INSTRUCT_CHAT_TEMPLATE: &str = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";
+
+pub struct Template {
+    configuration: Configuration,
+}
+
+// impl Template {
+//     pub fn new(configuration: Configuration) -> Self {
+//         Self { configuration }
+//     }
+// }
+
+pub fn apply_prompt(
+    chat_messages: Vec<ChatMessage>,
+    chat: &Chat,
+    tokenizer: Option<&Tokenizer>,
+) -> anyhow::Result<String> {
+    // If we have the chat template apply it
+    // If we have the chat_format see if we have it set
+    // If we don't have the chat_format set here, try and get the chat_template from the tokenizer_config.json file
+    anyhow::bail!("Please set chat_template or chat_format. Could not find the information in the tokenizer_config.json file")
+}
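`apply_prompt` is still a stub that bails, but the commented-out constants plus the new `minijinja` dependency show where it is headed: render the configured `chat_template` against the formatted messages. A minimal sketch of that rendering step; the simplified ChatML-style template string and the message shape are assumptions based on the commented constants, not code from this commit:

    use minijinja::{context, Environment};
    use serde::Serialize;

    #[derive(Serialize)]
    struct ChatMessage {
        role: String,
        content: String,
    }

    fn main() -> anyhow::Result<()> {
        // ChatML-like template, simplified from the commented-out constant above.
        let chat_template = "{% for message in messages %}<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";

        let mut env = Environment::new();
        env.add_template("chat", chat_template)?;

        let messages = vec![
            ChatMessage { role: "system".into(), content: "You complete code.".into() },
            ChatMessage { role: "user".into(), content: "def fib(n):".into() },
        ];

        let prompt = env
            .get_template("chat")?
            .render(context! { messages => messages, add_generation_prompt => true })?;
        println!("{prompt}");
        Ok(())
    }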
src/tokenizer.rs (new file): 7 lines

@@ -0,0 +1,7 @@
+pub struct Tokenizer {}
+
+impl Tokenizer {
+    pub fn maybe_from_repo(repo: ApiRepo) -> anyhow::Result<Option<Self>> {
+        unimplemented!()
+    }
+}
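`Tokenizer::maybe_from_repo` is left as `unimplemented!()` (the stub does not even import `ApiRepo` yet). Based on the comment in `apply_prompt` about falling back to the repository's tokenizer_config.json, here is one way it could plausibly be filled in; the `TokenizerConfig` shape and the missing-file handling are assumptions, not the commit's code:

    use std::fs;

    use hf_hub::api::sync::ApiRepo;
    use serde::Deserialize;

    // Assumed shape: only the optional chat_template field is of interest here.
    #[derive(Deserialize)]
    struct TokenizerConfig {
        chat_template: Option<String>,
    }

    pub struct Tokenizer {
        pub chat_template: Option<String>,
    }

    impl Tokenizer {
        pub fn maybe_from_repo(repo: ApiRepo) -> anyhow::Result<Option<Self>> {
            // Not every repository ships a tokenizer_config.json; treat a missing
            // file as "no tokenizer information" rather than an error.
            let path = match repo.get("tokenizer_config.json") {
                Ok(path) => path,
                Err(_) => return Ok(None),
            };
            let config: TokenizerConfig = serde_json::from_str(&fs::read_to_string(path)?)?;
            Ok(Some(Self {
                chat_template: config.chat_template,
            }))
        }
    }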
src/transformer_backends/llama_cpp/mod.rs:

@@ -3,7 +3,11 @@ use hf_hub::api::sync::Api;

 use super::TransformerBackend;
 use crate::{
-    configuration::Configuration,
+    configuration::{Chat, Configuration},
+    memory_backends::Prompt,
+    template::{apply_prompt, Template},
+    tokenizer::Tokenizer,
+    utils::format_chat_messages,
     worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
@@ -15,6 +19,7 @@ use model::Model;
 pub struct LlamaCPP {
     model: Model,
     configuration: Configuration,
+    tokenizer: Option<Tokenizer>,
 }

 impl LlamaCPP {
@@ -27,29 +32,42 @@ impl LlamaCPP {
             .context("Model `name` is required when using GGUF models")?;
         let repo = api.model(model.repository.to_owned());
         let model_path = repo.get(&name)?;
+        let tokenizer: Option<Tokenizer> = Tokenizer::maybe_from_repo(repo)?;
         let model = Model::new(model_path, configuration.get_model_kwargs()?)?;

         Ok(Self {
             model,
             configuration,
+            tokenizer,
         })
     }
 }

 impl TransformerBackend for LlamaCPP {
-    fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse> {
+    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+        // We need to check that they not only set the `chat` key, but they set the `completion` sub key
+        let prompt = match self.configuration.get_chat() {
+            Some(c) => {
+                if let Some(completion_messages) = &c.completion {
+                    let chat_messages = format_chat_messages(completion_messages, prompt);
+                    apply_prompt(chat_messages, c, self.tokenizer.as_ref())?
+                } else {
+                    prompt.code.to_owned()
+                }
+            }
+            None => prompt.code.to_owned(),
+        };
         let max_new_tokens = self.configuration.get_max_new_tokens().completion;
         self.model
-            .complete(prompt, max_new_tokens)
+            .complete(&prompt, max_new_tokens)
             .map(|insert_text| DoCompletionResponse { insert_text })
     }

-    fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse> {
-        let max_new_tokens = self.configuration.get_max_new_tokens().generation;
-        self.model
-            .complete(prompt, max_new_tokens)
-            .map(|generated_text| DoGenerateResponse { generated_text })
+    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+        unimplemented!()
+        // let max_new_tokens = self.configuration.get_max_new_tokens().generation;
+        // self.model
+        //     .complete(prompt, max_new_tokens)
+        //     .map(|generated_text| DoGenerateResponse { generated_text })
     }

     fn do_generate_stream(
@@ -74,8 +92,6 @@ mod tests {
             },
             "macos": {
                 "model_gguf": {
-                    // "repository": "deepseek-coder-6.7b-base",
-                    // "name": "Q4_K_M.gguf",
                     "repository": "stabilityai/stable-code-3b",
                     "name": "stable-code-3b-Q5_K_M.gguf",
                     "max_new_tokens": {
@@ -110,7 +126,6 @@ mod tests {
                        ]
                    },
                    "n_ctx": 2048,
-                    "n_threads": 8,
                    "n_gpu_layers": 1000,
                }
            },
@@ -118,7 +133,7 @@ mod tests {
        });
        let configuration = Configuration::new(args).unwrap();
        let model = LlamaCPP::new(configuration).unwrap();
-        let output = model.do_completion("def fibon").unwrap();
-        println!("{}", output.insert_text);
+        // let output = model.do_completion("def fibon").unwrap();
+        // println!("{}", output.insert_text);
     }
 }
src/transformer_backends/mod.rs:

@@ -1,5 +1,6 @@
 use crate::{
     configuration::{Configuration, ValidTransformerBackend},
+    memory_backends::Prompt,
     worker::{
         DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
     },
@@ -9,8 +10,8 @@ pub mod llama_cpp;

 pub trait TransformerBackend {
     // Should all take an enum of chat messages or just a string for completion
-    fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse>;
-    fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse>;
+    fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse>;
+    fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse>;
     fn do_generate_stream(
         &self,
         request: &GenerateStreamRequest,
src/utils.rs: 19 lines changed

@@ -1,5 +1,7 @@
 use lsp_server::ResponseError;

+use crate::{configuration::ChatMessage, memory_backends::Prompt};
+
 pub trait ToResponseError {
     fn to_response_error(&self, code: i32) -> ResponseError;
 }
@@ -13,3 +15,20 @@ impl ToResponseError for anyhow::Error {
         }
     }
 }
+
+pub fn characters_to_estimated_tokens(characters: usize) -> usize {
+    characters * 4
+}
+
+pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {
+    messages
+        .iter()
+        .map(|m| ChatMessage {
+            role: m.role.to_owned(),
+            message: m
+                .message
+                .replace("{context}", &prompt.context)
+                .replace("{code}", &prompt.code),
+        })
+        .collect()
+}
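`format_chat_messages` is pure placeholder substitution: every `{context}` and `{code}` in the configured messages is replaced with the corresponding `Prompt` field. (Note that `characters_to_estimated_tokens`, despite its name, multiplies by four, so at its call sites in file_store.rs it turns the configured context length into an estimated character count.) A self-contained sketch of the substitution behaviour; the structs are restated here without their serde derives so the example stands alone:

    #[derive(Debug, Clone)]
    pub struct ChatMessage {
        pub role: String,
        pub message: String,
    }

    #[derive(Debug)]
    pub struct Prompt {
        pub context: String,
        pub code: String,
    }

    pub fn format_chat_messages(messages: &Vec<ChatMessage>, prompt: &Prompt) -> Vec<ChatMessage> {
        messages
            .iter()
            .map(|m| ChatMessage {
                role: m.role.to_owned(),
                message: m
                    .message
                    .replace("{context}", &prompt.context)
                    .replace("{code}", &prompt.code),
            })
            .collect()
    }

    fn main() {
        let messages = vec![ChatMessage {
            role: "user".to_string(),
            message: "Complete this code:\n{code}".to_string(),
        }];
        let prompt = Prompt {
            context: "".to_string(),
            code: "def fib(n):".to_string(),
        };
        let formatted = format_chat_messages(&messages, &prompt);
        assert_eq!(formatted[0].message, "Complete this code:\ndef fib(n):");
    }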
src/worker.rs:

@@ -36,6 +36,8 @@ impl GenerateRequest {
     }
 }

+// The generate stream is not yet ready but we don't want to remove it
+#[allow(dead_code)]
 #[derive(Clone)]
 pub struct GenerateStreamRequest {
     id: RequestId,
@@ -98,10 +100,10 @@ impl Worker {
             .memory_backend
             .lock()
             .get_filter_text(&request.params.text_document_position)?;
-        eprintln!("\nPROMPT**************\n{}\n******************\n", prompt);
+        eprintln!("\nPROMPT**************\n{:?}\n******************\n", prompt);
         let response = self.transformer_backend.do_completion(&prompt)?;
         eprintln!(
-            "\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{}\n&&&&&&&&&&&&&&&&&&\n",
+            "\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{:?}\n&&&&&&&&&&&&&&&&&&\n",
             response.insert_text
         );
         let completion_text_edit = TextEdit::new(
@@ -142,7 +144,7 @@ impl Worker {
             .memory_backend
             .lock()
             .build_prompt(&request.params.text_document_position)?;
-        eprintln!("\nPROMPT*************\n{}\n************\n", prompt);
+        eprintln!("\nPROMPT*************\n{:?}\n************\n", prompt);
         let response = self.transformer_backend.do_generate(&prompt)?;
         let result = GenerateResult {
             generated_text: response.generated_text,