mirror of https://github.com/SilasMarvin/lsp-ai.git

Commit: Added backends

design.txt (new file, 23 lines)
@@ -0,0 +1,23 @@
# Overview

LSP AI should support multiple transformer_backends:

- Python - LLAMA CPP
- Python - SOME OTHER LIBRARY
- PostgresML

pub trait TransformerBackend {
    // These all take a memory backend as an argument
    do_completion()
    do_generate()
    do_generate_stream()
}

LSP AI should support multiple memory_backends:

- SIMPLE FILE STORE
- IN MEMORY VECTOR STORE
- PostgresML

pub trait MemoryBackend {
    // Plus methods that react to file open/change/rename notifications
    get_context() // Depending on the memory backend this can do very different things
}
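
Below is a minimal, self-contained Rust sketch of how these two traits might compose; the names and signatures are assumptions drawn from this note (the committed code uses build_prompt and richer response types), so treat it as an illustration rather than the final API.

// Sketch only: trait shapes assumed from the design note above.
trait MemoryBackend {
    // Build the context around the cursor; how this happens is backend-specific.
    fn get_context(&self, uri: &str, cursor: usize) -> String;
}

trait TransformerBackend {
    // Every operation consumes context produced by a memory backend.
    fn do_completion(&self, context: &str) -> String;
    fn do_generate(&self, context: &str) -> String;
}

// A completion request flows memory backend -> transformer backend.
fn complete(
    memory: &dyn MemoryBackend,
    model: &dyn TransformerBackend,
    uri: &str,
    cursor: usize,
) -> String {
    model.do_completion(&memory.get_context(uri, cursor))
}

fn main() {
    struct Echo;
    impl MemoryBackend for Echo {
        fn get_context(&self, uri: &str, cursor: usize) -> String {
            format!("{uri}@{cursor}")
        }
    }
    impl TransformerBackend for Echo {
        fn do_completion(&self, context: &str) -> String {
            format!("completion for {context}")
        }
        fn do_generate(&self, context: &str) -> String {
            format!("generation for {context}")
        }
    }
    println!("{}", complete(&Echo, &Echo, "file:///main.rs", 42));
}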

@@ -24,7 +24,17 @@
        "command": "lsp-ai.generateStream",
        "title": "LSP AI Generate Stream"
      }
    ]
    ],
    "configuration": {
      "title": "Configuration",
      "properties": {
        "configuration.json": {
          "type": "json",
          "default": "{}",
          "description": "JSON configuration for LSP AI"
        }
      }
    }
  },
  "devDependencies": {
    "@types/node": "^20.11.0",
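
The configuration.json property above holds a JSON string. A minimal sketch of the server-side handoff, assuming the extension forwards that string as the server's initializationOptions (that wiring is not shown in this hunk); Configuration::new, added later in this commit, then consumes the resulting value:

// Sketch only: turn the raw `configuration.json` string (default "{}") into the
// `initializationOptions` wrapper that Configuration::new expects.
fn parse_client_configuration(raw: &str) -> anyhow::Result<serde_json::Value> {
    let options: serde_json::Value = serde_json::from_str(raw)?;
    Ok(serde_json::json!({ "initializationOptions": options }))
}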

src/configuration.rs (new file, 181 lines)
@@ -0,0 +1,181 @@
use anyhow::{Context, Result};
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;

#[cfg(target_os = "macos")]
const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;

const DEFAULT_MAX_COMPLETION_TOKENS: usize = 32;
const DEFAULT_MAX_GENERATION_TOKENS: usize = 256;

pub enum ValidMemoryBackend {
    FileStore,
    PostgresML,
}

pub enum ValidTransformerBackend {
    LlamaCPP,
    PostgresML,
}

// TODO: Review this for real lol
#[derive(Clone, Deserialize)]
pub struct FIM {
    prefix: String,
    middle: String,
    suffix: String,
}

// TODO: Add some default things
#[derive(Clone, Deserialize)]
pub struct MaxNewTokens {
    pub completion: usize,
    pub generation: usize,
}

impl Default for MaxNewTokens {
    fn default() -> Self {
        Self {
            completion: DEFAULT_MAX_COMPLETION_TOKENS,
            generation: DEFAULT_MAX_GENERATION_TOKENS,
        }
    }
}

#[derive(Clone, Deserialize)]
struct ValidMemoryConfiguration {
    file_store: Option<Value>,
}

#[derive(Clone, Deserialize)]
struct ModelGGUF {
    repository: String,
    name: String,
    // Fill-in-the-middle support
    fim: Option<FIM>,
    // The maximum number of new tokens to generate
    #[serde(default)]
    max_new_tokens: MaxNewTokens,
    // Kwargs passed to LlamaCPP
    #[serde(flatten)]
    kwargs: HashMap<String, Value>,
}

#[derive(Clone, Deserialize)]
struct ValidMacTransformerConfiguration {
    model_gguf: Option<ModelGGUF>,
}

#[derive(Clone, Deserialize)]
struct ValidLinuxTransformerConfiguration {
    model_gguf: Option<ModelGGUF>,
}

#[derive(Clone, Deserialize)]
struct ValidConfiguration {
    memory: ValidMemoryConfiguration,
    // TODO: Add rename here
    #[cfg(target_os = "macos")]
    #[serde(alias = "macos")]
    transformer: ValidMacTransformerConfiguration,
    #[cfg(target_os = "linux")]
    #[serde(alias = "linux")]
    transformer: ValidLinuxTransformerConfiguration,
}

#[derive(Clone)]
pub struct Configuration {
    valid_config: ValidConfiguration,
}

impl Configuration {
    pub fn new(mut args: Value) -> Result<Self> {
        let configuration_args = args
            .as_object_mut()
            .context("Server configuration must be a JSON object")?
            .remove("initializationOptions")
            .unwrap_or_default();
        let valid_args: ValidConfiguration = serde_json::from_value(configuration_args)?;
        // TODO: Make sure they only specified one model or something ya know
        Ok(Self {
            valid_config: valid_args,
        })
    }

    pub fn get_memory_backend(&self) -> Result<ValidMemoryBackend> {
        if self.valid_config.memory.file_store.is_some() {
            Ok(ValidMemoryBackend::FileStore)
        } else {
            anyhow::bail!("Invalid memory configuration")
        }
    }

    pub fn get_transformer_backend(&self) -> Result<ValidTransformerBackend> {
        if self.valid_config.transformer.model_gguf.is_some() {
            Ok(ValidTransformerBackend::LlamaCPP)
        } else {
            anyhow::bail!("Invalid model configuration")
        }
    }

    pub fn get_maximum_context_length(&self) -> usize {
        if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
            model_gguf
                .kwargs
                .get("n_ctx")
                .map(|v| {
                    v.as_u64()
                        .map(|u| u as usize)
                        .unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
                })
                .unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
        } else {
            panic!("We currently only support gguf models using llama cpp")
        }
    }

    pub fn get_max_new_tokens(&self) -> &MaxNewTokens {
        if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
            &model_gguf.max_new_tokens
        } else {
            panic!("We currently only support gguf models using llama cpp")
        }
    }

    pub fn supports_fim(&self) -> bool {
        false
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn custom_mac_gguf_model() {
        let args = json!({
            "initializationOptions": {
                "memory": {
                    "file_store": {}
                },
                "macos": {
                    "model_gguf": {
                        "repository": "deepseek-coder-6.7b-base",
                        "name": "Q4_K_M.gguf",
                        "max_new_tokens": {
                            "completion": 32,
                            "generation": 256,
                        },
                        "n_ctx": 2048,
                        "n_threads": 8,
                        "n_gpu_layers": 35,
                        "chat_template": "",
                    }
                },
            }
        });
        let _ = Configuration::new(args).unwrap();
    }
}
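
A short usage sketch of the API above (not part of the commit; assumes it runs inside this crate with the types in scope). The JSON mirrors the unit test and test.json, and the match arms cover exactly the variants the getters can currently return:

use crate::configuration::{Configuration, ValidMemoryBackend, ValidTransformerBackend};
use serde_json::json;

fn pick_backends() -> anyhow::Result<()> {
    // Same shape an LSP client would send as initializationOptions.
    let args = json!({
        "initializationOptions": {
            "memory": { "file_store": {} },
            "macos": {
                "model_gguf": {
                    "repository": "deepseek-coder-6.7b-base",
                    "name": "Q4_K_M.gguf",
                    "n_ctx": 2048
                }
            }
        }
    });
    let config = Configuration::new(args)?;
    // Dispatch on whichever sections were present in the JSON.
    match config.get_memory_backend()? {
        ValidMemoryBackend::FileStore => eprintln!("using the in-process file store"),
        ValidMemoryBackend::PostgresML => eprintln!("PostgresML memory backend"),
    }
    match config.get_transformer_backend()? {
        ValidTransformerBackend::LlamaCPP => eprintln!("llama.cpp via the Python bridge"),
        ValidTransformerBackend::PostgresML => eprintln!("PostgresML transformer backend"),
    }
    Ok(())
}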

@@ -1,4 +1,4 @@
use lsp_types::{PartialResultParams, ProgressToken, TextDocumentPositionParams};
use lsp_types::{ProgressToken, TextDocumentPositionParams};
use serde::{Deserialize, Serialize};

pub enum GenerateStream {}

src/main.rs (128 lines)
@@ -1,31 +1,28 @@
use anyhow::{Context, Result};
use anyhow::Result;

use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId};
use lsp_types::{
    request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
    RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use pyo3::prelude::*;
use ropey::Rope;
use serde::Deserialize;
use std::{collections::HashMap, sync::Arc, thread};
use std::{sync::Arc, thread};

mod configuration;
mod custom_requests;
mod memory_backends;
mod transformer_backends;
mod utils;
mod worker;

use configuration::Configuration;
use custom_requests::generate::Generate;
use worker::{CompletionRequest, GenerateRequest, WorkerRequest};
use memory_backends::MemoryBackend;
use transformer_backends::TransformerBackend;
use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest};

use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest};

pub static PY_MODULE: Lazy<Result<Py<PyAny>>> = Lazy::new(|| {
    pyo3::Python::with_gil(|py| -> Result<Py<PyAny>> {
        let src = include_str!("python/transformers.py");
        Ok(pyo3::types::PyModule::from_code(py, src, "transformers.py", "transformers")?.into())
    })
});

// Taken directly from: https://github.com/rust-lang/rust-analyzer
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
    notification.method == N::METHOD
@@ -52,55 +49,47 @@ fn main() -> Result<()> {
        )),
        ..Default::default()
    })?;
    let initialization_params = connection.initialize(server_capabilities)?;
    let initialization_args = connection.initialize(server_capabilities)?;

    // Activate the python venv
    Python::with_gil(|py| -> Result<()> {
        let activate: Py<PyAny> = PY_MODULE
            .as_ref()
            .map_err(anyhow::Error::msg)?
            .getattr(py, "activate_venv")?;

        activate.call1(py, ("/Users/silas/Projects/lsp-ai/venv",))?;
        Ok(())
    })?;

    main_loop(connection, initialization_params)?;
    main_loop(connection, initialization_args)?;
    io_threads.join()?;
    Ok(())
}

#[derive(Deserialize)]
struct Params {}

// This main loop is tricky
// We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get
// Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request
fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
    let _params: Params = serde_json::from_value(params)?;
// Note that we also want to have the memory backend in the worker thread as that may also involve heavy computations
fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
    let args = Configuration::new(args)?;

    // Set the model
    Python::with_gil(|py| -> Result<()> {
        let activate: Py<PyAny> = PY_MODULE
            .as_ref()
            .map_err(anyhow::Error::msg)?
            .getattr(py, "set_model")?;
        activate.call1(py, ("",))?;
        Ok(())
    })?;
    // Set the transformer_backend
    let transformer_backend: Box<dyn TransformerBackend + Send> = args.clone().try_into()?;
    transformer_backend.init()?;

    // Prep variables
    // Set the memory_backend
    let memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>> =
        Arc::new(Mutex::new(args.clone().try_into()?));

    // Wrap the connection for sharing between threads
    let connection = Arc::new(connection);
    let mut file_map: HashMap<String, Rope> = HashMap::new();

    // How we communicate between the worker and receiver threads
    let last_worker_request = Arc::new(Mutex::new(None));

    // Thread local variables
    let thread_memory_backend = memory_backend.clone();
    let thread_last_worker_request = last_worker_request.clone();
    let thread_connection = connection.clone();
    // TODO: Pass some backend into here
    thread::spawn(move || {
        worker::run(thread_last_worker_request, thread_connection);
        Worker::new(
            transformer_backend,
            thread_memory_backend,
            thread_last_worker_request,
            thread_connection,
        )
        .run();
    });

    for msg in &connection.receiver {
@@ -115,41 +104,30 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
                if request_is::<Completion>(&req) {
                    match cast::<Completion>(req) {
                        Ok((id, params)) => {
                            let rope = file_map
                                .get(params.text_document_position.text_document.uri.as_str())
                                .context("Error file not found")?
                                .clone();
                            eprintln!("******{:?}********", id);
                            let mut lcr = last_worker_request.lock();
                            let completion_request = CompletionRequest::new(id, params, rope);
                            let completion_request = CompletionRequest::new(id, params);
                            *lcr = Some(WorkerRequest::Completion(completion_request));
                        }
                        Err(err) => panic!("{err:?}"),
                        Err(err) => eprintln!("{err:?}"),
                    }
                } else if request_is::<Generate>(&req) {
                    match cast::<Generate>(req) {
                        Ok((id, params)) => {
                            let rope = file_map
                                .get(params.text_document_position.text_document.uri.as_str())
                                .context("Error file not found")?
                                .clone();
                            let mut lcr = last_worker_request.lock();
                            let completion_request = GenerateRequest::new(id, params, rope);
                            let completion_request = GenerateRequest::new(id, params);
                            *lcr = Some(WorkerRequest::Generate(completion_request));
                        }
                        Err(err) => panic!("{err:?}"),
                        Err(err) => eprintln!("{err:?}"),
                    }
                } else if request_is::<GenerateStream>(&req) {
                    match cast::<GenerateStream>(req) {
                        Ok((id, params)) => {
                            let rope = file_map
                                .get(params.text_document_position.text_document.uri.as_str())
                                .context("Error file not found")?
                                .clone();
                            let mut lcr = last_worker_request.lock();
                            let completion_request = GenerateStreamRequest::new(id, params, rope);
                            let completion_request = GenerateStreamRequest::new(id, params);
                            *lcr = Some(WorkerRequest::GenerateStream(completion_request));
                        }
                        Err(err) => panic!("{err:?}"),
                        Err(err) => eprintln!("{err:?}"),
                    }
                } else {
                    eprintln!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream")
@@ -158,33 +136,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
            Message::Notification(not) => {
                if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
                    let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
                    let rope = Rope::from_str(&params.text_document.text);
                    file_map.insert(params.text_document.uri.to_string(), rope);
                    memory_backend.lock().opened_text_document(params)?;
                } else if notification_is::<lsp_types::notification::DidChangeTextDocument>(&not) {
                    let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
                    let rope = file_map
                        .get_mut(params.text_document.uri.as_str())
                        .context("Error trying to get file that does not exist")?;
                    for change in params.content_changes {
                        // If range is omitted, text is the new text of the document
                        if let Some(range) = change.range {
                            let start_index = rope.line_to_char(range.start.line as usize)
                                + range.start.character as usize;
                            let end_index = rope.line_to_char(range.end.line as usize)
                                + range.end.character as usize;
                            rope.remove(start_index..end_index);
                            rope.insert(start_index, &change.text);
                        } else {
                            *rope = Rope::from_str(&change.text);
                        }
                    }
                    memory_backend.lock().changed_text_document(params)?;
                } else if notification_is::<lsp_types::notification::DidRenameFiles>(&not) {
                    let params: RenameFilesParams = serde_json::from_value(not.params)?;
                    for file_rename in params.files {
                        if let Some(rope) = file_map.remove(&file_rename.old_uri) {
                            file_map.insert(file_rename.new_uri, rope);
                        }
                    }
                    memory_backend.lock().renamed_file(params)?;
                }
            }
            _ => (),
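
The "only the latest request matters" hand-off described in the comments above reduces to a single shared slot that the receiver overwrites and the worker drains. A self-contained sketch of just that pattern (simplified types, not the commit's code):

use parking_lot::Mutex;
use std::{sync::Arc, thread, time::Duration};

fn main() {
    // One shared slot: newer requests silently replace older, unprocessed ones.
    let last_request: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));

    let worker_slot = last_request.clone();
    let worker = thread::spawn(move || {
        for _ in 0..10 {
            // Take whatever is there, leaving None behind.
            let request = { std::mem::take(&mut *worker_slot.lock()) };
            if let Some(request) = request {
                println!("processing {request}");
            }
            thread::sleep(Duration::from_millis(5));
        }
    });

    // The receiver side just overwrites the slot; stale requests are dropped.
    for i in 0..100 {
        *last_request.lock() = Some(format!("completion request {i}"));
    }
    worker.join().unwrap();
}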

src/memory_backends/file_store.rs (new file, 98 lines)
@@ -0,0 +1,98 @@
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use ropey::Rope;
use std::collections::HashMap;

use crate::configuration::Configuration;

use super::MemoryBackend;

pub struct FileStore {
    configuration: Configuration,
    file_map: HashMap<String, Rope>,
}

impl FileStore {
    pub fn new(configuration: Configuration) -> Self {
        Self {
            configuration,
            file_map: HashMap::new(),
        }
    }
}

impl MemoryBackend for FileStore {
    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
        let rope = self
            .file_map
            .get(position.text_document.uri.as_str())
            .context("Error file not found")?
            .clone();

        if self.configuration.supports_fim() {
            // We will want to add some kind of infill support here
            // rope.insert(cursor_index, "<|fim_hole|>");
            // rope.insert(0, "<|fim_start|>");
            // rope.insert(rope.len_chars(), "<|fim_end|>");
            // let prompt = rope.to_string();
            unimplemented!()
        } else {
            // Convert rope to correct prompt for llm
            let cursor_index = rope.line_to_char(position.position.line as usize)
                + position.position.character as usize;

            let start = cursor_index
                .checked_sub(self.configuration.get_maximum_context_length())
                .unwrap_or(0);
            eprintln!("############ {start} - {cursor_index} #############");

            Ok(rope
                .get_slice(start..cursor_index)
                .context("Error getting rope slice")?
                .to_string())
        }
    }

    fn opened_text_document(
        &mut self,
        params: lsp_types::DidOpenTextDocumentParams,
    ) -> anyhow::Result<()> {
        let rope = Rope::from_str(&params.text_document.text);
        self.file_map
            .insert(params.text_document.uri.to_string(), rope);
        Ok(())
    }

    fn changed_text_document(
        &mut self,
        params: lsp_types::DidChangeTextDocumentParams,
    ) -> anyhow::Result<()> {
        let rope = self
            .file_map
            .get_mut(params.text_document.uri.as_str())
            .context("Error trying to get file that does not exist")?;
        for change in params.content_changes {
            // If range is omitted, text is the new text of the document
            if let Some(range) = change.range {
                let start_index =
                    rope.line_to_char(range.start.line as usize) + range.start.character as usize;
                let end_index =
                    rope.line_to_char(range.end.line as usize) + range.end.character as usize;
                rope.remove(start_index..end_index);
                rope.insert(start_index, &change.text);
            } else {
                *rope = Rope::from_str(&change.text);
            }
        }
        Ok(())
    }

    fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
        for file_rename in params.files {
            if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {
                self.file_map.insert(file_rename.new_uri, rope);
            }
        }
        Ok(())
    }
}
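
A rough usage sketch of the FileStore flow above (the URI, text, and cursor position are made up; assumes it runs inside this crate, e.g. as a test): open a document, then build the prompt, which is everything before the cursor capped at get_maximum_context_length() characters.

use crate::memory_backends::{file_store::FileStore, MemoryBackend};
use lsp_types::{Position, TextDocumentIdentifier, TextDocumentPositionParams, Url};

fn file_store_round_trip(mut store: FileStore) -> anyhow::Result<String> {
    let uri = Url::parse("file:///tmp/example.rs")?;

    // The client opens a document; the store keeps a Rope copy of it.
    store.opened_text_document(lsp_types::DidOpenTextDocumentParams {
        text_document: lsp_types::TextDocumentItem {
            uri: uri.clone(),
            language_id: "rust".to_string(),
            version: 0,
            text: "fn main() {\n    let x = \n}\n".to_string(),
        },
    })?;

    // Build the prompt that would be handed to the transformer backend.
    store.build_prompt(&TextDocumentPositionParams {
        text_document: TextDocumentIdentifier { uri },
        position: Position::new(1, 12),
    })
}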

src/memory_backends/mod.rs (new file, 31 lines)
@@ -0,0 +1,31 @@
use lsp_types::{
    DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
    TextDocumentPositionParams,
};

use crate::configuration::{Configuration, ValidMemoryBackend};

pub mod file_store;

pub trait MemoryBackend {
    fn init(&self) -> anyhow::Result<()> {
        Ok(())
    }
    fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
    fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
    fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
    fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
}

impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
    type Error = anyhow::Error;

    fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
        match configuration.get_memory_backend()? {
            ValidMemoryBackend::FileStore => {
                Ok(Box::new(file_store::FileStore::new(configuration)))
            }
            _ => unimplemented!(),
        }
    }
}

(deleted file, 44 lines)
@@ -1,44 +0,0 @@
import sys
import os

from llama_cpp import Llama


model = None


def activate_venv(venv):
    if sys.platform in ('win32', 'win64', 'cygwin'):
        activate_this = os.path.join(venv, 'Scripts', 'activate_this.py')
    else:
        activate_this = os.path.join(venv, 'bin', 'activate_this.py')

    if os.path.exists(activate_this):
        exec(open(activate_this).read(), dict(__file__=activate_this))
        return True
    else:
        print(f"Virtualenv not found: {venv}", file=sys.stderr)
        return False


def set_model(filler):
    global model
    model = Llama(
        # model_path="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",  # Download the model file first
        model_path="/Users/silas/Projects/Tests/lsp-ai-tests/deepseek-coder-6.7b-base.Q4_K_M.gguf",  # Download the model file first
        n_ctx=2048,  # The max sequence length to use - note that longer sequence lengths require much more resources
        n_threads=8,  # The number of CPU threads to use, tailor to your system and the resulting performance
        n_gpu_layers=35  # The number of layers to offload to GPU, if you have GPU acceleration available
    )


def transform(input):
    # Simple inference example
    output = model(
        input,  # Prompt
        max_tokens=32,  # Generate up to 512 tokens
        stop=["<|EOT|>"],  # Example stop token - not necessarily correct for this specific model! Please check before using.
        echo=False  # Whether to echo the prompt
    )
    return output["choices"][0]["text"]

src/transformer_backends/llama_cpp/mod.rs (new file, 88 lines)
@@ -0,0 +1,88 @@
use crate::{
    configuration::Configuration,
    worker::{
        DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
    },
};

use super::TransformerBackend;
use once_cell::sync::Lazy;
use pyo3::prelude::*;

pub static PY_MODULE: Lazy<anyhow::Result<Py<PyAny>>> = Lazy::new(|| {
    pyo3::Python::with_gil(|py| -> anyhow::Result<Py<PyAny>> {
        let src = include_str!("python/transformers.py");
        Ok(pyo3::types::PyModule::from_code(py, src, "transformers.py", "transformers")?.into())
    })
});

pub struct LlamaCPP {
    configuration: Configuration,
}

impl LlamaCPP {
    pub fn new(configuration: Configuration) -> Self {
        Self { configuration }
    }
}

impl TransformerBackend for LlamaCPP {
    fn init(&self) -> anyhow::Result<()> {
        // Activate the python venv
        Python::with_gil(|py| -> anyhow::Result<()> {
            let activate: Py<PyAny> = PY_MODULE
                .as_ref()
                .map_err(anyhow::Error::msg)?
                .getattr(py, "activate_venv")?;
            activate.call1(py, ("/Users/silas/Projects/lsp-ai/venv",))?;
            Ok(())
        })?;
        // Set the model
        Python::with_gil(|py| -> anyhow::Result<()> {
            let activate: Py<PyAny> = PY_MODULE
                .as_ref()
                .map_err(anyhow::Error::msg)?
                .getattr(py, "set_model")?;
            activate.call1(py, ("",))?;
            Ok(())
        })?;
        Ok(())
    }

    fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse> {
        let max_new_tokens = self.configuration.get_max_new_tokens().completion;
        Python::with_gil(|py| -> anyhow::Result<String> {
            let transform: Py<PyAny> = PY_MODULE
                .as_ref()
                .map_err(anyhow::Error::msg)?
                .getattr(py, "transform")?;

            let out: String = transform.call1(py, (prompt, max_new_tokens))?.extract(py)?;
            Ok(out)
        })
        .map(|insert_text| DoCompletionResponse { insert_text })
    }

    fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse> {
        let max_new_tokens = self.configuration.get_max_new_tokens().generation;
        Python::with_gil(|py| -> anyhow::Result<String> {
            let transform: Py<PyAny> = PY_MODULE
                .as_ref()
                .map_err(anyhow::Error::msg)?
                .getattr(py, "transform")?;

            let out: String = transform.call1(py, (prompt, max_new_tokens))?.extract(py)?;
            Ok(out)
        })
        .map(|generated_text| DoGenerateResponse { generated_text })
    }

    fn do_generate_stream(
        &self,
        request: &GenerateStreamRequest,
    ) -> anyhow::Result<DoGenerateStreamResponse> {
        Ok(DoGenerateStreamResponse {
            generated_text: "".to_string(),
        })
    }
}

src/transformer_backends/llama_cpp/python/transformers.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import sys
import os

from llama_cpp import Llama


model = None


def activate_venv(venv):
    if sys.platform in ("win32", "win64", "cygwin"):
        activate_this = os.path.join(venv, "Scripts", "activate_this.py")
    else:
        activate_this = os.path.join(venv, "bin", "activate_this.py")

    if os.path.exists(activate_this):
        exec(open(activate_this).read(), dict(__file__=activate_this))
        return True
    else:
        print(f"Virtualenv not found: {venv}", file=sys.stderr)
        return False


def set_model(filler):
    global model
    model = Llama(
        # model_path="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",  # Download the model file first
        model_path="/Users/silas/Projects/Tests/lsp-ai-tests/deepseek-coder-6.7b-base.Q4_K_M.gguf",  # Download the model file first
        n_ctx=2048,  # The max sequence length to use - note that longer sequence lengths require much more resources
        n_threads=8,  # The number of CPU threads to use, tailor to your system and the resulting performance
        n_gpu_layers=35,  # The number of layers to offload to GPU, if you have GPU acceleration available
    )


def transform(input, max_tokens):
    # Simple inference example
    output = model(
        input,  # Prompt
        max_tokens=max_tokens,  # Generate up to max tokens
        # stop=[
        #     "<|EOT|>"
        # ],  # Example stop token - not necessarily correct for this specific model! Please check before using.
        echo=False,  # Whether to echo the prompt
    )
    return output["choices"][0]["text"]

src/transformer_backends/mod.rs (new file, 32 lines)
@@ -0,0 +1,32 @@
use crate::{
    configuration::{Configuration, ValidTransformerBackend},
    worker::{
        DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateRequest,
        GenerateStreamRequest,
    },
};

pub mod llama_cpp;

pub trait TransformerBackend {
    fn init(&self) -> anyhow::Result<()>;
    fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse>;
    fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse>;
    fn do_generate_stream(
        &self,
        request: &GenerateStreamRequest,
    ) -> anyhow::Result<DoGenerateStreamResponse>;
}

impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send> {
    type Error = anyhow::Error;

    fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
        match configuration.get_transformer_backend()? {
            ValidTransformerBackend::LlamaCPP => {
                Ok(Box::new(llama_cpp::LlamaCPP::new(configuration)))
            }
            _ => unimplemented!(),
        }
    }
}
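
For reference, a tiny sketch isolating how these TryFrom conversions are meant to be used (main.rs in this commit does the same thing inline); nothing here is new API:

use crate::{
    configuration::Configuration, memory_backends::MemoryBackend,
    transformer_backends::TransformerBackend,
};

fn build_backends(
    config: Configuration,
) -> anyhow::Result<(
    Box<dyn TransformerBackend + Send>,
    Box<dyn MemoryBackend + Send>,
)> {
    // Each TryFrom inspects the validated configuration and boxes the matching backend.
    let transformer: Box<dyn TransformerBackend + Send> = config.clone().try_into()?;
    let memory: Box<dyn MemoryBackend + Send> = config.try_into()?;
    Ok((transformer, memory))
}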

src/utils.rs (new file, 15 lines)
@@ -0,0 +1,15 @@
use lsp_server::ResponseError;

pub trait ToResponseError {
    fn to_response_error(&self, code: i32) -> ResponseError;
}

impl ToResponseError for anyhow::Error {
    fn to_response_error(&self, code: i32) -> ResponseError {
        ResponseError {
            code,
            message: self.to_string(),
            data: None,
        }
    }
}

src/worker.rs (new file, 187 lines)
@@ -0,0 +1,187 @@
use lsp_server::{Connection, Message, RequestId, Response};
use lsp_types::{
    CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
    Position, Range, TextEdit,
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};

use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::{GenerateStreamParams, GenerateStreamResult};
use crate::memory_backends::MemoryBackend;
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;

#[derive(Clone)]
pub struct CompletionRequest {
    id: RequestId,
    params: CompletionParams,
}

impl CompletionRequest {
    pub fn new(id: RequestId, params: CompletionParams) -> Self {
        Self { id, params }
    }
}

#[derive(Clone)]
pub struct GenerateRequest {
    id: RequestId,
    params: GenerateParams,
}

impl GenerateRequest {
    pub fn new(id: RequestId, params: GenerateParams) -> Self {
        Self { id, params }
    }
}

#[derive(Clone)]
pub struct GenerateStreamRequest {
    id: RequestId,
    params: GenerateStreamParams,
}

impl GenerateStreamRequest {
    pub fn new(id: RequestId, params: GenerateStreamParams) -> Self {
        Self { id, params }
    }
}

#[derive(Clone)]
pub enum WorkerRequest {
    Completion(CompletionRequest),
    Generate(GenerateRequest),
    GenerateStream(GenerateStreamRequest),
}

pub struct DoCompletionResponse {
    pub insert_text: String,
}

pub struct DoGenerateResponse {
    pub generated_text: String,
}

pub struct DoGenerateStreamResponse {
    pub generated_text: String,
}

pub struct Worker {
    transformer_backend: Box<dyn TransformerBackend>,
    memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
    last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
    connection: Arc<Connection>,
}

impl Worker {
    pub fn new(
        transformer_backend: Box<dyn TransformerBackend>,
        memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
        last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
        connection: Arc<Connection>,
    ) -> Self {
        Self {
            transformer_backend,
            memory_backend,
            last_worker_request,
            connection,
        }
    }

    fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
        let prompt = self
            .memory_backend
            .lock()
            .build_prompt(&request.params.text_document_position)?;
        eprintln!("\n\n****************{}***************\n\n", prompt);
        let response = self.transformer_backend.do_completion(&prompt)?;
        eprintln!(
            "\n\n****************{}***************\n\n",
            response.insert_text
        );
        let completion_text_edit = TextEdit::new(
            Range::new(
                Position::new(
                    request.params.text_document_position.position.line,
                    request.params.text_document_position.position.character,
                ),
                Position::new(
                    request.params.text_document_position.position.line,
                    request.params.text_document_position.position.character,
                ),
            ),
            response.insert_text.clone(),
        );
        let item = CompletionItem {
            label: format!("ai - {}", response.insert_text),
            text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
            kind: Some(CompletionItemKind::TEXT),
            ..Default::default()
        };
        let completion_list = CompletionList {
            is_incomplete: false,
            items: vec![item],
        };
        let result = Some(CompletionResponse::List(completion_list));
        let result = serde_json::to_value(&result).unwrap();
        Ok(Response {
            id: request.id.clone(),
            result: Some(result),
            error: None,
        })
    }

    fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
        let prompt = self
            .memory_backend
            .lock()
            .build_prompt(&request.params.text_document_position)?;
        eprintln!("\n\n****************{}***************\n\n", prompt);
        let response = self.transformer_backend.do_generate(&prompt)?;
        let result = GenerateResult {
            generated_text: response.generated_text,
        };
        let result = serde_json::to_value(&result).unwrap();
        Ok(Response {
            id: request.id.clone(),
            result: Some(result),
            error: None,
        })
    }

    pub fn run(self) {
        loop {
            let option_worker_request: Option<WorkerRequest> = {
                let mut completion_request = self.last_worker_request.lock();
                std::mem::take(&mut *completion_request)
            };
            if let Some(request) = option_worker_request {
                let response = match request {
                    WorkerRequest::Completion(request) => match self.do_completion(&request) {
                        Ok(r) => r,
                        Err(e) => Response {
                            id: request.id,
                            result: None,
                            error: Some(e.to_response_error(-32603)),
                        },
                    },
                    WorkerRequest::Generate(request) => match self.do_generate(&request) {
                        Ok(r) => r,
                        Err(e) => Response {
                            id: request.id,
                            result: None,
                            error: Some(e.to_response_error(-32603)),
                        },
                    },
                    WorkerRequest::GenerateStream(_) => panic!("Streaming is not supported yet"),
                };
                self.connection
                    .sender
                    .send(Message::Response(response))
                    .expect("Error sending message");
            }
            thread::sleep(std::time::Duration::from_millis(5));
        }
    }
}

(deleted file, 61 lines)
@@ -1,61 +0,0 @@
use lsp_server::ResponseError;
use pyo3::prelude::*;

use super::CompletionRequest;
use crate::PY_MODULE;

pub struct DoCompletionResponse {
    pub insert_text: String,
    pub filter_text: String,
}

pub fn do_completion(request: &CompletionRequest) -> Result<DoCompletionResponse, ResponseError> {
    let filter_text = request
        .rope
        .get_line(request.params.text_document_position.position.line as usize)
        .ok_or(ResponseError {
            code: -32603, // Maybe we want a different error code here?
            message: "Error getting line in requested document".to_string(),
            data: None,
        })?
        .to_string();

    // Convert rope to correct prompt for llm
    let cursor_index = request
        .rope
        .line_to_char(request.params.text_document_position.position.line as usize)
        + request.params.text_document_position.position.character as usize;

    // We will want to have some kind of infill support we add
    // rope.insert(cursor_index, "<|fim_hole|>");
    // rope.insert(0, "<|fim_start|>");
    // rope.insert(rope.len_chars(), "<|fim_end|>");
    // let prompt = rope.to_string();

    let prompt = request
        .rope
        .get_slice(0..cursor_index)
        .expect("Error getting rope slice")
        .to_string();

    eprintln!("\n\n****{prompt}****\n\n");

    Python::with_gil(|py| -> anyhow::Result<String> {
        let transform: Py<PyAny> = PY_MODULE
            .as_ref()
            .map_err(anyhow::Error::msg)?
            .getattr(py, "transform")?;

        let out: String = transform.call1(py, (prompt,))?.extract(py)?;
        Ok(out)
    })
    .map(|insert_text| DoCompletionResponse {
        insert_text,
        filter_text,
    })
    .map_err(|e| ResponseError {
        code: -32603,
        message: e.to_string(),
        data: None,
    })
}

(deleted file, 47 lines)
@@ -1,47 +0,0 @@
use lsp_server::ResponseError;
use pyo3::prelude::*;

use super::{GenerateRequest, GenerateStreamRequest};
use crate::PY_MODULE;

pub struct DoGenerateResponse {
    pub generated_text: String,
}

pub fn do_generate(request: &GenerateRequest) -> Result<DoGenerateResponse, ResponseError> {
    // Convert rope to correct prompt for llm
    let cursor_index = request
        .rope
        .line_to_char(request.params.text_document_position.position.line as usize)
        + request.params.text_document_position.position.character as usize;

    // We will want to have some kind of infill support we add
    // rope.insert(cursor_index, "<|fim_hole|>");
    // rope.insert(0, "<|fim_start|>");
    // rope.insert(rope.len_chars(), "<|fim_end|>");
    // let prompt = rope.to_string();

    let prompt = request
        .rope
        .get_slice(0..cursor_index)
        .expect("Error getting rope slice")
        .to_string();

    eprintln!("\n\n****{prompt}****\n\n");

    Python::with_gil(|py| -> anyhow::Result<String> {
        let transform: Py<PyAny> = PY_MODULE
            .as_ref()
            .map_err(anyhow::Error::msg)?
            .getattr(py, "transform")?;

        let out: String = transform.call1(py, (prompt,))?.extract(py)?;
        Ok(out)
    })
    .map(|generated_text| DoGenerateResponse { generated_text })
    .map_err(|e| ResponseError {
        code: -32603,
        message: e.to_string(),
        data: None,
    })
}

(deleted file, 45 lines)
@@ -1,45 +0,0 @@
use lsp_server::ResponseError;
use pyo3::prelude::*;

use super::{GenerateRequest, GenerateStreamRequest};
use crate::PY_MODULE;

pub fn do_generate_stream(request: &GenerateStreamRequest) -> Result<(), ResponseError> {
    // Convert rope to correct prompt for llm
    // let cursor_index = request
    //     .rope
    //     .line_to_char(request.params.text_document_position.position.line as usize)
    //     + request.params.text_document_position.position.character as usize;

    // // We will want to have some kind of infill support we add
    // // rope.insert(cursor_index, "<|fim_hole|>");
    // // rope.insert(0, "<|fim_start|>");
    // // rope.insert(rope.len_chars(), "<|fim_end|>");
    // // let prompt = rope.to_string();

    // let prompt = request
    //     .rope
    //     .get_slice(0..cursor_index)
    //     .expect("Error getting rope slice")
    //     .to_string();

    // eprintln!("\n\n****{prompt}****\n\n");

    // Python::with_gil(|py| -> anyhow::Result<String> {
    //     let transform: Py<PyAny> = PY_MODULE
    //         .as_ref()
    //         .map_err(anyhow::Error::msg)?
    //         .getattr(py, "transform")?;

    //     let out: String = transform.call1(py, (prompt,))?.extract(py)?;
    //     Ok(out)
    // })
    // .map(|generated_text| DoGenerateResponse { generated_text })
    // .map_err(|e| ResponseError {
    //     code: -32603,
    //     message: e.to_string(),
    //     data: None,
    // })

    Ok(())
}

(deleted file, 181 lines)
@@ -1,181 +0,0 @@
use lsp_server::{Connection, Message, RequestId, Response};
use lsp_types::{
    CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
    Position, Range, TextEdit,
};
use parking_lot::Mutex;
use ropey::Rope;
use std::{sync::Arc, thread};

mod completion;
mod generate;
mod generate_stream;

use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::{GenerateStreamParams, GenerateStreamResult};
use completion::do_completion;
use generate::do_generate;
use generate_stream::do_generate_stream;

#[derive(Clone)]
pub struct CompletionRequest {
    id: RequestId,
    params: CompletionParams,
    rope: Rope,
}

impl CompletionRequest {
    pub fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self {
        Self { id, params, rope }
    }
}

#[derive(Clone)]
pub struct GenerateRequest {
    id: RequestId,
    params: GenerateParams,
    rope: Rope,
}

impl GenerateRequest {
    pub fn new(id: RequestId, params: GenerateParams, rope: Rope) -> Self {
        Self { id, params, rope }
    }
}

#[derive(Clone)]
pub struct GenerateStreamRequest {
    id: RequestId,
    params: GenerateStreamParams,
    rope: Rope,
}

impl GenerateStreamRequest {
    pub fn new(id: RequestId, params: GenerateStreamParams, rope: Rope) -> Self {
        Self { id, params, rope }
    }
}

#[derive(Clone)]
pub enum WorkerRequest {
    Completion(CompletionRequest),
    Generate(GenerateRequest),
    GenerateStream(GenerateStreamRequest),
}

pub fn run(last_worker_request: Arc<Mutex<Option<WorkerRequest>>>, connection: Arc<Connection>) {
    loop {
        let option_worker_request: Option<WorkerRequest> = {
            let mut completion_request = last_worker_request.lock();
            std::mem::take(&mut *completion_request)
        };
        if let Some(request) = option_worker_request {
            let response = match request {
                WorkerRequest::Completion(request) => match do_completion(&request) {
                    Ok(response) => {
                        let completion_text_edit = TextEdit::new(
                            Range::new(
                                Position::new(
                                    request.params.text_document_position.position.line,
                                    request.params.text_document_position.position.character,
                                ),
                                Position::new(
                                    request.params.text_document_position.position.line,
                                    request.params.text_document_position.position.character,
                                ),
                            ),
                            response.insert_text.clone(),
                        );
                        let item = CompletionItem {
                            label: format!("ai - {}", response.insert_text),
                            filter_text: Some(response.filter_text),
                            text_edit: Some(lsp_types::CompletionTextEdit::Edit(
                                completion_text_edit,
                            )),
                            kind: Some(CompletionItemKind::TEXT),
                            ..Default::default()
                        };
                        let completion_list = CompletionList {
                            is_incomplete: false,
                            items: vec![item],
                        };
                        let result = Some(CompletionResponse::List(completion_list));
                        let result = serde_json::to_value(&result).unwrap();
                        Response {
                            id: request.id,
                            result: Some(result),
                            error: None,
                        }
                    }
                    Err(e) => Response {
                        id: request.id,
                        result: None,
                        error: Some(e),
                    },
                },
                WorkerRequest::Generate(request) => match do_generate(&request) {
                    Ok(result) => {
                        let result = GenerateResult {
                            generated_text: result.generated_text,
                        };
                        let result = serde_json::to_value(&result).unwrap();
                        Response {
                            id: request.id,
                            result: Some(result),
                            error: None,
                        }
                    }
                    Err(e) => Response {
                        id: request.id,
                        result: None,
                        error: Some(e),
                    },
                },
                WorkerRequest::GenerateStream(request) => match do_generate_stream(&request) {
                    Ok(result) => {
                        // let result = GenerateResult {
                        //     generated_text: result.generated_text,
                        // };
                        // let result = serde_json::to_value(&result).unwrap();
                        let result = GenerateStreamResult {
                            generated_text: "test".to_string(),
                            partial_result_token: request.params.partial_result_token,
                        };
                        let result = serde_json::to_value(&result).unwrap();
                        Response {
                            id: request.id,
                            result: Some(result),
                            error: None,
                        }
                    }
                    Err(e) => Response {
                        id: request.id,
                        result: None,
                        error: Some(e),
                    },
                },
            };
            connection
                .sender
                .send(Message::Response(response.clone()))
                .expect("Error sending response");
            connection
                .sender
                .send(Message::Response(response.clone()))
                .expect("Error sending response");
            connection
                .sender
                .send(Message::Response(response.clone()))
                .expect("Error sending response");
            // connection
            //     .sender
            //     .send(Message::Response(Response {
            //         id: response.id,
            //         result: None,
            //         error: None,
            //     }))
            //     .expect("Error sending message");
        }
        thread::sleep(std::time::Duration::from_millis(5));
    }
}

test.json (new file, 18 lines)
@@ -0,0 +1,18 @@
{
  "macos": {
    "model_gguf": {
      "repository": "deepseek-coder-6.7b-base",
      "name": "Q4_K_M.gguf",
      "fim": false,
      "n_ctx": 2048,
      "n_threads": 8,
      "n_gpu_layers": 35
    }
  },
  "linux": {
    "model_gptq": {
      "repository": "theblokesomething",
      "name": "some q5 or something"
    }
  }
}