mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2025-12-18 23:14:28 +01:00
The beginning of something awesome
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
/target
/models
1997
Cargo.lock
generated
Normal file
1997
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
26
Cargo.toml
Normal file
26
Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
[package]
name = "lsp-ai"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.75"
lsp-server = "0.7.4"
lsp-types = "0.94.1"
once_cell = "1.18.0"
parking_lot = "0.12.1"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
# candle-core = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# candle-nn = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# candle-transformers = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# NOTE(review): these path dependencies require a sibling checkout at ../candle,
# so the crate only builds on this machine — consider switching back to the git
# dependencies commented out above for a reproducible build.
candle-core = { path = "../candle/candle-core", version = "0.3.1", features = ["accelerate"] }
candle-nn = { path = "../candle/candle-nn", version = "0.3.1", features = ["accelerate"] }
candle-transformers = { path = "../candle/candle-transformers", version = "0.3.1", features = ["accelerate"] }
hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" }
rand = "0.8.5"
tokenizers = "0.14.1"
3
run.sh
Executable file
3
run.sh
Executable file
@@ -0,0 +1,3 @@
|
|||||||
|
#!/usr/bin/env bash
# Launch the release build of lsp-ai.
# Resolve the binary relative to this script instead of a hard-coded absolute
# user path, so the launcher works from any checkout location; forward any
# arguments through to the server.
exec "$(dirname "$0")/target/release/lsp-ai" "$@"
194
src/main.rs
Normal file
194
src/main.rs
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
use anyhow::Context;
|
||||||
|
use anyhow::Result;
|
||||||
|
use core::panic;
|
||||||
|
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId, Response};
|
||||||
|
use lsp_types::{
|
||||||
|
request::Completion, CompletionItem, CompletionItemKind, CompletionList, CompletionOptions,
|
||||||
|
CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams, Position, Range,
|
||||||
|
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit,
|
||||||
|
};
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use serde::Deserialize;
|
||||||
|
// use pyo3::prelude::*;
|
||||||
|
// use pyo3::types::PyTuple;
|
||||||
|
use ropey::Rope;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
mod transformer;
|
||||||
|
|
||||||
|
// Global document store: maps an LSP document URI to that document's current
// contents as a `ropey::Rope`. Populated on `didOpen`, edited in place on
// `didChange`, and re-keyed on `didRenameFiles` (see `main_loop`).
static FILE_MAP: once_cell::sync::Lazy<Mutex<HashMap<String, Rope>>> =
    once_cell::sync::Lazy::new(|| Mutex::new(HashMap::new()));
||||||
|
|
||||||
|
// Taken directly from: https://github.com/rust-lang/rust-analyzer
|
||||||
|
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
|
||||||
|
notification.method == N::METHOD
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
let (connection, io_threads) = Connection::stdio();
|
||||||
|
let server_capabilities = serde_json::to_value(&ServerCapabilities {
|
||||||
|
completion_provider: Some(CompletionOptions::default()),
|
||||||
|
text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind(
|
||||||
|
TextDocumentSyncKind::INCREMENTAL,
|
||||||
|
)),
|
||||||
|
..Default::default()
|
||||||
|
})?;
|
||||||
|
let initialization_params = connection.initialize(server_capabilities)?;
|
||||||
|
main_loop(connection, initialization_params)?;
|
||||||
|
io_threads.join()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialization options sent by the client (all optional).
///
/// NOTE(review): these fields are deserialized in `main_loop` but not yet
/// consumed — `transformer::build` currently hard-codes its model choice.
#[derive(Deserialize)]
struct Params {
    // Model repository/name to load.
    model: Option<String>,
    // Specific weights file within the model repository.
    model_file: Option<String>,
    // Architecture/type of the model.
    model_type: Option<String>,
    // Device to run inference on (e.g. "cpu").
    device: Option<String>,
}
|
||||||
|
|
||||||
|
fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||||
|
let params: Params = serde_json::from_value(params)?;
|
||||||
|
let mut text_generation = transformer::build()?;
|
||||||
|
for msg in &connection.receiver {
|
||||||
|
match msg {
|
||||||
|
Message::Request(req) => {
|
||||||
|
if connection.handle_shutdown(&req)? {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
match cast::<Completion>(req) {
|
||||||
|
Ok((id, params)) => {
|
||||||
|
// Get rope
|
||||||
|
let file_map = FILE_MAP.lock();
|
||||||
|
let mut rope = file_map
|
||||||
|
.get(params.text_document_position.text_document.uri.as_str())
|
||||||
|
.context("Error file not found")?
|
||||||
|
.clone();
|
||||||
|
let filter_text = rope
|
||||||
|
.get_line(params.text_document_position.position.line as usize)
|
||||||
|
.context("Error getting line with ropey")?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// Convert rope to correct prompt for llm
|
||||||
|
let start_index = rope
|
||||||
|
.line_to_char(params.text_document_position.position.line as usize)
|
||||||
|
+ params.text_document_position.position.character as usize;
|
||||||
|
rope.insert(start_index, "<fim_suffix>");
|
||||||
|
let prompt = format!("<fim_prefix>{}<fim_middle>", rope);
|
||||||
|
let insert_text = text_generation.run(&prompt, 64)?;
|
||||||
|
|
||||||
|
// Create and return the completion
|
||||||
|
let completion_text_edit = TextEdit::new(
|
||||||
|
Range::new(
|
||||||
|
Position::new(
|
||||||
|
params.text_document_position.position.line,
|
||||||
|
params.text_document_position.position.character,
|
||||||
|
),
|
||||||
|
Position::new(
|
||||||
|
params.text_document_position.position.line,
|
||||||
|
params.text_document_position.position.character,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
insert_text.clone(),
|
||||||
|
);
|
||||||
|
let item = CompletionItem {
|
||||||
|
label: format!("ai - {insert_text}"),
|
||||||
|
filter_text: Some(filter_text),
|
||||||
|
text_edit: Some(lsp_types::CompletionTextEdit::Edit(
|
||||||
|
completion_text_edit,
|
||||||
|
)),
|
||||||
|
kind: Some(CompletionItemKind::TEXT),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let completion_list = CompletionList {
|
||||||
|
is_incomplete: false,
|
||||||
|
items: vec![item],
|
||||||
|
};
|
||||||
|
let result = Some(CompletionResponse::List(completion_list));
|
||||||
|
let result = serde_json::to_value(&result).unwrap();
|
||||||
|
let resp = Response {
|
||||||
|
id,
|
||||||
|
result: Some(result),
|
||||||
|
error: None,
|
||||||
|
};
|
||||||
|
connection.sender.send(Message::Response(resp))?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
|
||||||
|
Err(ExtractError::MethodMismatch(req)) => req,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Message::Notification(not) => {
|
||||||
|
eprintln!("got notification: {not:?}");
|
||||||
|
if notification_is::<lsp_types::notification::DidOpenTextDocument>(¬) {
|
||||||
|
let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
|
||||||
|
let rope = Rope::from_str(¶ms.text_document.text);
|
||||||
|
let mut file_map = FILE_MAP.lock();
|
||||||
|
file_map.insert(params.text_document.uri.to_string(), rope);
|
||||||
|
} else if notification_is::<lsp_types::notification::DidChangeTextDocument>(¬) {
|
||||||
|
let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
|
||||||
|
let mut file_map = FILE_MAP.lock();
|
||||||
|
let rope = file_map
|
||||||
|
.get_mut(params.text_document.uri.as_str())
|
||||||
|
.context("Error trying to get file that does not exist")?;
|
||||||
|
for change in params.content_changes {
|
||||||
|
// If range is ommitted, text is the new text of the document
|
||||||
|
if let Some(range) = change.range {
|
||||||
|
let start_index = rope.line_to_char(range.start.line as usize)
|
||||||
|
+ range.start.character as usize;
|
||||||
|
let end_index = rope.line_to_char(range.end.line as usize)
|
||||||
|
+ range.end.character as usize;
|
||||||
|
rope.remove(start_index..end_index);
|
||||||
|
rope.insert(start_index, &change.text);
|
||||||
|
} else {
|
||||||
|
*rope = Rope::from_str(&change.text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if notification_is::<lsp_types::notification::DidRenameFiles>(¬) {
|
||||||
|
let params: RenameFilesParams = serde_json::from_value(not.params)?;
|
||||||
|
let mut file_map = FILE_MAP.lock();
|
||||||
|
for file_rename in params.files {
|
||||||
|
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
|
||||||
|
file_map.insert(file_rename.new_uri, rope);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attempts to extract the typed parameters for LSP request type `R` from a
/// raw `Request`. On a method mismatch the original request is handed back
/// inside the error so the caller can try extracting a different request type.
fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
where
    R: lsp_types::request::Request,
    R::Params: serde::de::DeserializeOwned,
{
    req.extract(R::METHOD)
}
|
||||||
|
|
||||||
|
// #[cfg(test)]
|
||||||
|
// mod tests {
|
||||||
|
// use super::*;
|
||||||
|
|
||||||
|
// #[test]
|
||||||
|
// fn test_lsp() -> Result<()> {
|
||||||
|
// let prompt = "def sum_two_numers(x: int, y:";
|
||||||
|
// let result = Python::with_gil(|py| -> Result<String> {
|
||||||
|
// let transform: Py<PyAny> = PY_MODULE
|
||||||
|
// .as_ref()
|
||||||
|
// .expect("Error getting python module")
|
||||||
|
// .getattr(py, "transform")
|
||||||
|
// .expect("Error getting transform");
|
||||||
|
|
||||||
|
// let output = transform
|
||||||
|
// .call1(py, PyTuple::new(py, &[prompt]))
|
||||||
|
// .expect("Error calling transform");
|
||||||
|
|
||||||
|
// Ok(output.extract(py).expect("Error extracting result"))
|
||||||
|
// })?;
|
||||||
|
// println!("\n\nTHE RESULT\n{:?}\n\n", result);
|
||||||
|
// Ok(())
|
||||||
|
// }
|
||||||
|
// }
|
||||||
103
src/transformer.rs
Normal file
103
src/transformer.rs
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use candle_transformers::models::bigcode::{Config, GPTBigCode};
|
||||||
|
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
/// Token-by-token text generator wrapping a StarCoder-family model.
pub struct TextGeneration {
    // GPT-BigCode model used for next-token prediction.
    model: GPTBigCode,
    // Device input tensors are created on (CPU, as set up in `build`).
    device: Device,
    // Tokenizer used both to encode prompts and to decode sampled tokens.
    tokenizer: Tokenizer,
    // Seeded sampling strategy with optional temperature / top-p.
    logits_processor: LogitsProcessor,
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
fn new(
|
||||||
|
model: GPTBigCode,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
logits_processor,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<String> {
|
||||||
|
eprintln!("Starting to generate tokens");
|
||||||
|
let mut tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
let mut new_tokens = vec![];
|
||||||
|
let mut outputs = vec![];
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
|
||||||
|
(1, tokens.len().saturating_sub(1))
|
||||||
|
} else {
|
||||||
|
(tokens.len(), 0)
|
||||||
|
};
|
||||||
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input, past_len)?;
|
||||||
|
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
new_tokens.push(next_token);
|
||||||
|
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||||
|
outputs.push(token);
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
self.model.clear_cache();
|
||||||
|
eprintln!(
|
||||||
|
"GENERATED {} tokens in {} seconds",
|
||||||
|
outputs.len(),
|
||||||
|
dt.as_secs()
|
||||||
|
);
|
||||||
|
Ok(outputs.join(""))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build() -> Result<TextGeneration> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
eprintln!("Loading in model");
|
||||||
|
let api = ApiBuilder::new()
|
||||||
|
.with_token(Some(std::env::var("HF_TOKEN")?.to_string()))
|
||||||
|
.build()?;
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
"bigcode/starcoderbase-1b".to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
"main".to_string(),
|
||||||
|
));
|
||||||
|
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||||
|
let filenames = ["model.safetensors"]
|
||||||
|
.iter()
|
||||||
|
.map(|f| repo.get(f))
|
||||||
|
.collect::<std::result::Result<Vec<_>, _>>()?;
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
let device = Device::Cpu;
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||||
|
let config = Config::starcoder_1b();
|
||||||
|
let model = GPTBigCode::load(vb, config)?;
|
||||||
|
eprintln!("loaded the model in {:?}", start.elapsed());
|
||||||
|
Ok(TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
0,
|
||||||
|
Some(0.85),
|
||||||
|
None,
|
||||||
|
&device,
|
||||||
|
))
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user