Moving towards a trait system for models and cleaning up a ton of stuff

SilasMarvin
2023-11-24 15:10:45 -08:00
parent 8e66c66326
commit ac09f89da7
6 changed files with 268 additions and 139 deletions

Cargo.lock generated

@@ -3,10 +3,13 @@
version = 3
[[package]]
name = "accelerate-src"
version = "0.3.2"
name = "addr2line"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d"
checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
dependencies = [
"gimli",
]
[[package]]
name = "adler"
@@ -76,6 +79,9 @@ name = "anyhow"
version = "1.0.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
dependencies = [
"backtrace",
]
[[package]]
name = "autocfg"
@@ -83,6 +89,21 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "backtrace"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
]
[[package]]
name = "base64"
version = "0.13.1"
@@ -137,11 +158,11 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
name = "candle-core"
version = "0.3.1"
dependencies = [
"accelerate-src",
"byteorder",
"candle-kernels",
"cudarc",
"gemm",
"half",
"libc",
"memmap2",
"num-traits",
"num_cpus",
@@ -154,11 +175,19 @@ dependencies = [
"zip",
]
[[package]]
name = "candle-kernels"
version = "0.3.1"
dependencies = [
"anyhow",
"glob",
"rayon",
]
[[package]]
name = "candle-nn"
version = "0.3.1"
dependencies = [
"accelerate-src",
"candle-core",
"half",
"num-traits",
@@ -172,7 +201,6 @@ dependencies = [
name = "candle-transformers"
version = "0.3.1"
dependencies = [
"accelerate-src",
"byteorder",
"candle-core",
"candle-nn",
@@ -334,6 +362,15 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "cudarc"
version = "0.9.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1871a911a2b9a3f66a285896a719159985683bf9903aa2cf89e0c9f53e14552"
dependencies = [
"half",
]
[[package]]
name = "darling"
version = "0.14.4"
@@ -636,6 +673,18 @@ dependencies = [
"wasi",
]
[[package]]
name = "gimli"
version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "half"
version = "2.3.1"
@@ -793,7 +842,6 @@ dependencies = [
"hf-hub",
"lsp-server",
"lsp-types",
"once_cell",
"parking_lot",
"rand",
"ropey",
@@ -968,6 +1016,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "object"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
version = "1.18.0"
@@ -1306,6 +1363,12 @@ dependencies = [
"str_indices",
]
[[package]]
name = "rustc-demangle"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustix"
version = "0.38.25"

Cargo.toml

@@ -9,18 +9,20 @@ edition = "2021"
anyhow = "1.0.75"
lsp-server = "0.7.4"
lsp-types = "0.94.1"
once_cell = "1.18.0"
parking_lot = "0.12.1"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
# candle-core = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# candle-nn = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# candle-transformers = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
candle-core = { path = "../candle/candle-core", version = "0.3.1", features = ["accelerate"] }
candle-nn = { path = "../candle/candle-nn", version = "0.3.1", features = ["accelerate"] }
candle-transformers = { path = "../candle/candle-transformers", version = "0.3.1", features = ["accelerate"] }
candle-core = { path = "../candle/candle-core" }
candle-nn = { path = "../candle/candle-nn" }
candle-transformers = { path = "../candle/candle-transformers" }
hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" }
rand = "0.8.5"
tokenizers = "0.14.1"
parking_lot = "0.12.1"
[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]

run.sh

@@ -1,3 +1,3 @@
#!/usr/bin/env bash
/Users/silas/Projects/lsp-ai/target/release/lsp-ai
/home/silas/Projects/lsp-ai/target/release/lsp-ai

src/main.rs

@@ -4,26 +4,32 @@ use core::panic;
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId, Response};
use lsp_types::{
request::Completion, CompletionItem, CompletionItemKind, CompletionList, CompletionOptions,
CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams, Position, Range,
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit,
CompletionParams, CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
Position, Range, RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit,
};
use parking_lot::Mutex;
use serde::Deserialize;
// use pyo3::prelude::*;
// use pyo3::types::PyTuple;
use ropey::Rope;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
use std::thread;
mod transformer;
static FILE_MAP: once_cell::sync::Lazy<Mutex<HashMap<String, Rope>>> =
once_cell::sync::Lazy::new(|| Mutex::new(HashMap::new()));
mod models;
use models::{Model, ModelParams};
// Taken directly from: https://github.com/rust-lang/rust-analyzer
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
notification.method == N::METHOD
}
fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
where
R: lsp_types::request::Request,
R::Params: serde::de::DeserializeOwned,
{
req.extract(R::METHOD)
}
fn main() -> Result<()> {
let (connection, io_threads) = Connection::stdio();
let server_capabilities = serde_json::to_value(&ServerCapabilities {
@@ -41,15 +47,107 @@ fn main() -> Result<()> {
#[derive(Deserialize)]
struct Params {
model: Option<String>,
model_file: Option<String>,
model_type: Option<String>,
device: Option<String>,
// We may want to put other, non-model-related parameters here in the future
model_params: Option<ModelParams>,
}
struct CompletionRequest {
id: RequestId,
params: CompletionParams,
rope: Rope,
}
impl CompletionRequest {
fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self {
Self { id, params, rope }
}
}
// This main loop is a little tricky
// We spawn a worker thread to do the heavy lifting so that we do not have to process every completion request we receive
// A completion request may take a few seconds depending on the model configuration and the available hardware, so we only ever process the latest one
fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
let params: Params = serde_json::from_value(params)?;
let mut text_generation = transformer::build()?;
// Prep variables
let connection = Arc::new(connection);
let mut file_map: HashMap<String, Rope> = HashMap::new();
// How we communicate between the worker and receiver threads
let last_completion_request = Arc::new(Mutex::new(None));
// Thread local variables
let thread_last_completion_request = last_completion_request.clone();
let thread_connection = connection.clone();
// We need to allow unreachable_code to be able to use the question mark operator here
// We could probably restructure this to not require it
#[allow(unreachable_code)]
thread::spawn(move || {
// Build the model from the params
let mut model: Box<dyn Model> = params.model_params.unwrap_or_default().try_into()?;
loop {
// Take the pending request and drop the lock right away so the receiver thread is not blocked while the model runs
let mut completion_request = thread_last_completion_request.lock();
let params = std::mem::take(&mut *completion_request);
drop(completion_request);
if let Some(CompletionRequest {
id,
params,
mut rope,
}) = params
{
let filter_text = rope
.get_line(params.text_document_position.position.line as usize)
.context("Error getting line with ropey")?
.to_string();
// Convert rope to correct prompt for llm
let start_index = rope
.line_to_char(params.text_document_position.position.line as usize)
+ params.text_document_position.position.character as usize;
rope.insert(start_index, "<fim_suffix>");
let prompt = format!("<fim_prefix>{}<fim_middle>", rope);
let insert_text = model.run(&prompt)?;
// Create and return the completion
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
params.text_document_position.position.line,
params.text_document_position.position.character,
),
Position::new(
params.text_document_position.position.line,
params.text_document_position.position.character,
),
),
insert_text.clone(),
);
let item = CompletionItem {
label: format!("ai - {insert_text}"),
filter_text: Some(filter_text),
text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
kind: Some(CompletionItemKind::TEXT),
..Default::default()
};
let completion_list = CompletionList {
is_incomplete: false,
items: vec![item],
};
let result = Some(CompletionResponse::List(completion_list));
let result = serde_json::to_value(&result).unwrap();
let resp = Response {
id,
result: Some(result),
error: None,
};
thread_connection.sender.send(Message::Response(resp))?;
}
thread::sleep(std::time::Duration::from_millis(5));
}
anyhow::Ok(())
});
for msg in &connection.receiver {
match msg {
Message::Request(req) => {
@@ -59,59 +157,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
match cast::<Completion>(req) {
Ok((id, params)) => {
// Get rope
let file_map = FILE_MAP.lock();
let mut rope = file_map
let rope = file_map
.get(params.text_document_position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
let filter_text = rope
.get_line(params.text_document_position.position.line as usize)
.context("Error getting line with ropey")?
.to_string();
// Convert rope to correct prompt for llm
let start_index = rope
.line_to_char(params.text_document_position.position.line as usize)
+ params.text_document_position.position.character as usize;
rope.insert(start_index, "<fim_suffix>");
let prompt = format!("<fim_prefix>{}<fim_middle>", rope);
let insert_text = text_generation.run(&prompt, 64)?;
// Create and return the completion
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
params.text_document_position.position.line,
params.text_document_position.position.character,
),
Position::new(
params.text_document_position.position.line,
params.text_document_position.position.character,
),
),
insert_text.clone(),
);
let item = CompletionItem {
label: format!("ai - {insert_text}"),
filter_text: Some(filter_text),
text_edit: Some(lsp_types::CompletionTextEdit::Edit(
completion_text_edit,
)),
kind: Some(CompletionItemKind::TEXT),
..Default::default()
};
let completion_list = CompletionList {
is_incomplete: false,
items: vec![item],
};
let result = Some(CompletionResponse::List(completion_list));
let result = serde_json::to_value(&result).unwrap();
let resp = Response {
id,
result: Some(result),
error: None,
};
connection.sender.send(Message::Response(resp))?;
// Update the last CompletionRequest
let mut lcr = last_completion_request.lock();
*lcr = Some(CompletionRequest::new(id, params, rope));
continue;
}
Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
@@ -123,11 +175,9 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
let rope = Rope::from_str(&params.text_document.text);
let mut file_map = FILE_MAP.lock();
file_map.insert(params.text_document.uri.to_string(), rope);
} else if notification_is::<lsp_types::notification::DidChangeTextDocument>(&not) {
let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
let mut file_map = FILE_MAP.lock();
let rope = file_map
.get_mut(params.text_document.uri.as_str())
.context("Error trying to get file that does not exist")?;
@@ -146,7 +196,6 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
}
} else if notification_is::<lsp_types::notification::DidRenameFiles>(&not) {
let params: RenameFilesParams = serde_json::from_value(not.params)?;
let mut file_map = FILE_MAP.lock();
for file_rename in params.files {
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
file_map.insert(file_rename.new_uri, rope);
@@ -159,36 +208,3 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
}
Ok(())
}
fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
where
R: lsp_types::request::Request,
R::Params: serde::de::DeserializeOwned,
{
req.extract(R::METHOD)
}
// #[cfg(test)]
// mod tests {
// use super::*;
// #[test]
// fn test_lsp() -> Result<()> {
// let prompt = "def sum_two_numers(x: int, y:";
// let result = Python::with_gil(|py| -> Result<String> {
// let transform: Py<PyAny> = PY_MODULE
// .as_ref()
// .expect("Error getting python module")
// .getattr(py, "transform")
// .expect("Error getting transform");
// let output = transform
// .call1(py, PyTuple::new(py, &[prompt]))
// .expect("Error calling transform");
// Ok(output.extract(py).expect("Error extracting result"))
// })?;
// println!("\n\nTHE RESULT\n{:?}\n\n", result);
// Ok(())
// }
// }
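A note on the worker-thread design above: the receiver thread and the worker share a single-slot, latest-request-wins mailbox (`Arc<Mutex<Option<CompletionRequest>>>`). Below is a minimal standalone sketch of that pattern using the same `parking_lot::Mutex`; the `Request` type, prompt strings, and sleep at the end are placeholders, not part of the commit.

```rust
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use parking_lot::Mutex;

// Stand-in for the CompletionRequest in main.rs, which carries the RequestId,
// the CompletionParams, and a Rope snapshot of the document.
struct Request(String);

fn main() {
    // A single-slot "mailbox": the receiver overwrites it, the worker drains it.
    let mailbox: Arc<Mutex<Option<Request>>> = Arc::new(Mutex::new(None));

    let worker_mailbox = mailbox.clone();
    thread::spawn(move || loop {
        // Take the pending request and release the lock before doing the slow
        // work, so the other thread can keep posting newer requests meanwhile.
        let pending = std::mem::take(&mut *worker_mailbox.lock());
        if let Some(Request(prompt)) = pending {
            // Placeholder for model.run(&prompt): the slow part goes here.
            println!("processing: {prompt}");
        }
        thread::sleep(Duration::from_millis(5));
    });

    // The receiver side simply overwrites the slot; a request that is replaced
    // before the worker picks it up is silently dropped, which is the point.
    *mailbox.lock() = Some(Request("first".into()));
    *mailbox.lock() = Some(Request("second - this one wins".into()));

    // Give the worker a moment before the process exits; the real loop runs
    // for the lifetime of the language server.
    thread::sleep(Duration::from_millis(50));
}
```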
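The prompt the worker builds follows the StarCoder fill-in-the-middle format: the cursor position is marked with `<fim_suffix>` and the whole document is wrapped in `<fim_prefix>…<fim_middle>`. A small sketch of that assembly using ropey; the helper name, example document, and cursor position are made up for illustration.

```rust
use ropey::Rope;

// Minimal sketch of the fill-in-the-middle prompt assembly done in the
// worker thread, given a document and an LSP (line, character) cursor.
fn build_fim_prompt(text: &str, line: usize, character: usize) -> String {
    let mut rope = Rope::from_str(text);
    // Translate the (line, character) position into a char index and mark
    // the cursor with <fim_suffix>; everything after it becomes the suffix.
    let cursor = rope.line_to_char(line) + character;
    rope.insert(cursor, "<fim_suffix>");
    // StarCoder-style FIM: prefix and suffix surround the cursor, and the
    // model is asked to generate the middle.
    format!("<fim_prefix>{}<fim_middle>", rope)
}

fn main() {
    let doc = "def add(x, y):\n    return\n";
    // Cursor at line 1, character 10, i.e. right after "    return".
    let prompt = build_fim_prompt(doc, 1, 10);
    println!("{prompt}");
    // Printed prompt:
    //   <fim_prefix>def add(x, y):
    //       return<fim_suffix>
    //   <fim_middle>
}
```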

src/models/mod.rs (new file)

@@ -0,0 +1,32 @@
use anyhow::Result;
use serde::Deserialize;
mod starcoder;
pub trait Model {
fn run(&mut self, prompt: &str) -> Result<String>;
}
#[derive(Deserialize, Default)]
pub struct ModelParams {
model: Option<String>,
model_file: Option<String>,
model_type: Option<String>,
max_length: Option<usize>,
}
impl TryFrom<ModelParams> for Box<dyn Model> {
type Error = anyhow::Error;
fn try_from(value: ModelParams) -> Result<Self> {
let model_type = value.model_type.unwrap_or("starcoder".to_string());
let max_length = value.max_length.unwrap_or(12);
Ok(Box::new(match model_type.as_str() {
"starcoder" => starcoder::build_model(value.model, value.model_file, max_length)?,
_ => anyhow::bail!(
"Model type: {} not supported. Feel free to make a PR or create a GitHub issue.",
model_type
),
}))
}
}
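The point of the new trait is that an additional backend only has to implement `Model::run` and claim one arm of the `TryFrom` dispatch. A hypothetical sketch of what a second backend added to src/models/mod.rs could look like; the `llama` module, its fields, and the `"llama"` config value are illustrative and not part of this commit.

```rust
// Hypothetical second backend living alongside `mod starcoder;` in mod.rs.
mod llama {
    use anyhow::Result;

    pub struct Model {
        max_length: usize,
        // weights, tokenizer, device, etc. would live here
    }

    pub fn build_model(
        _model: Option<String>,
        _model_file: Option<String>,
        max_length: usize,
    ) -> Result<Model> {
        Ok(Model { max_length })
    }

    impl super::Model for Model {
        fn run(&mut self, _prompt: &str) -> Result<String> {
            // A real backend would tokenize the prompt, generate up to
            // self.max_length tokens, and decode them here.
            Ok(format!("generated up to {} tokens", self.max_length))
        }
    }
}
```

Note that once two arms actually return models, each arm has to box its own concrete type (`"llama" => Box::new(llama::build_model(...)?)` and likewise for starcoder) instead of relying on the single outer `Box::new` around the match used above.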

src/models/starcoder.rs

@@ -6,32 +6,16 @@ use candle_transformers::models::bigcode::{Config, GPTBigCode};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;
pub struct TextGeneration {
pub struct Model {
model: GPTBigCode,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
max_length: usize,
}
impl TextGeneration {
fn new(
model: GPTBigCode,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
device: device.clone(),
}
}
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<String> {
impl super::Model for Model {
fn run(&mut self, prompt: &str) -> Result<String> {
eprintln!("Starting to generate tokens");
let mut tokens = self
.tokenizer
@@ -42,7 +26,7 @@ impl TextGeneration {
let mut new_tokens = vec![];
let mut outputs = vec![];
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
for index in 0..self.max_length {
let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
(1, tokens.len().saturating_sub(1))
} else {
@@ -62,15 +46,40 @@ impl TextGeneration {
let dt = start_gen.elapsed();
self.model.clear_cache();
eprintln!(
"GENERATED {} tokens in {} seconds",
"GENERATED {} tokens in {} milliseconds",
outputs.len(),
dt.as_secs()
dt.as_millis()
);
Ok(outputs.join(""))
}
}
pub fn build() -> Result<TextGeneration> {
impl Model {
fn new(
model: GPTBigCode,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
device: &Device,
max_length: usize,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
device: device.clone(),
max_length,
}
}
}
pub fn build_model(
model: Option<String>,
model_file: Option<String>,
max_length: usize,
) -> Result<Model> {
let start = std::time::Instant::now();
eprintln!("Loading in model");
let api = ApiBuilder::new()
@@ -87,17 +96,24 @@ pub fn build() -> Result<TextGeneration> {
.map(|f| repo.get(f))
.collect::<std::result::Result<Vec<_>, _>>()?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
// Set the device
#[cfg(feature = "cuda")]
let device = Device::new_cuda(0)?;
#[cfg(not(feature = "cuda"))]
let device = Device::Cpu;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::starcoder_1b();
let model = GPTBigCode::load(vb, config)?;
eprintln!("loaded the model in {:?}", start.elapsed());
Ok(TextGeneration::new(
Ok(Model::new(
model,
tokenizer,
0,
Some(0.85),
None,
&device,
max_length,
))
}
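For completeness, a sketch of how the new configuration is meant to flow from the editor into a model: the client's initializationOptions deserialize into `Params` / `ModelParams`, and `try_into()` then builds the boxed model. The struct shapes and defaults mirror this commit; the JSON values are only an example client config.

```rust
use serde::Deserialize;
use serde_json::json;

// Mirrors the shapes introduced in this commit (main.rs Params and
// models::ModelParams); only the fields relevant to this sketch are used.
#[allow(dead_code)]
#[derive(Deserialize, Default)]
struct ModelParams {
    model: Option<String>,
    model_file: Option<String>,
    model_type: Option<String>,
    max_length: Option<usize>,
}

#[derive(Deserialize)]
struct Params {
    model_params: Option<ModelParams>,
}

fn main() -> anyhow::Result<()> {
    // What an editor might send as LSP initializationOptions.
    let init_options = json!({
        "model_params": {
            "model_type": "starcoder",
            "max_length": 64
        }
    });

    let params: Params = serde_json::from_value(init_options)?;
    let model_params = params.model_params.unwrap_or_default();

    // In main.rs this is where `params.model_params.unwrap_or_default().try_into()?`
    // produces a `Box<dyn Model>`; here we just print the resolved values,
    // falling back to the same defaults mod.rs uses ("starcoder", 12).
    println!(
        "model_type = {}, max_length = {}",
        model_params.model_type.as_deref().unwrap_or("starcoder"),
        model_params.max_length.unwrap_or(12)
    );
    Ok(())
}
```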