mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2026-01-10 02:04:18 +01:00
Moving towards trait system for models and cleaned up a ton of stuff
This commit is contained in:
79
Cargo.lock
generated
79
Cargo.lock
generated
@@ -3,10 +3,13 @@
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "accelerate-src"
|
||||
version = "0.3.2"
|
||||
name = "addr2line"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d"
|
||||
checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
|
||||
dependencies = [
|
||||
"gimli",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adler"
|
||||
@@ -76,6 +79,9 @@ name = "anyhow"
|
||||
version = "1.0.75"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
@@ -83,6 +89,21 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837"
|
||||
dependencies = [
|
||||
"addr2line",
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"miniz_oxide",
|
||||
"object",
|
||||
"rustc-demangle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.1"
|
||||
@@ -137,11 +158,11 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||
name = "candle-core"
|
||||
version = "0.3.1"
|
||||
dependencies = [
|
||||
"accelerate-src",
|
||||
"byteorder",
|
||||
"candle-kernels",
|
||||
"cudarc",
|
||||
"gemm",
|
||||
"half",
|
||||
"libc",
|
||||
"memmap2",
|
||||
"num-traits",
|
||||
"num_cpus",
|
||||
@@ -154,11 +175,19 @@ dependencies = [
|
||||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-kernels"
|
||||
version = "0.3.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"glob",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-nn"
|
||||
version = "0.3.1"
|
||||
dependencies = [
|
||||
"accelerate-src",
|
||||
"candle-core",
|
||||
"half",
|
||||
"num-traits",
|
||||
@@ -172,7 +201,6 @@ dependencies = [
|
||||
name = "candle-transformers"
|
||||
version = "0.3.1"
|
||||
dependencies = [
|
||||
"accelerate-src",
|
||||
"byteorder",
|
||||
"candle-core",
|
||||
"candle-nn",
|
||||
@@ -334,6 +362,15 @@ version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||
|
||||
[[package]]
|
||||
name = "cudarc"
|
||||
version = "0.9.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1871a911a2b9a3f66a285896a719159985683bf9903aa2cf89e0c9f53e14552"
|
||||
dependencies = [
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.14.4"
|
||||
@@ -636,6 +673,18 @@ dependencies = [
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.28.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.3.1"
|
||||
@@ -793,7 +842,6 @@ dependencies = [
|
||||
"hf-hub",
|
||||
"lsp-server",
|
||||
"lsp-types",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"rand",
|
||||
"ropey",
|
||||
@@ -968,6 +1016,15 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.32.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.18.0"
|
||||
@@ -1306,6 +1363,12 @@ dependencies = [
|
||||
"str_indices",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.25"
|
||||
|
||||
12
Cargo.toml
12
Cargo.toml
@@ -9,18 +9,20 @@ edition = "2021"
|
||||
anyhow = "1.0.75"
|
||||
lsp-server = "0.7.4"
|
||||
lsp-types = "0.94.1"
|
||||
once_cell = "1.18.0"
|
||||
parking_lot = "0.12.1"
|
||||
ropey = "1.6.1"
|
||||
serde = "1.0.190"
|
||||
serde_json = "1.0.108"
|
||||
# candle-core = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
|
||||
# candle-nn = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
|
||||
# candle-transformers = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
|
||||
candle-core = { path = "../candle/candle-core", version = "0.3.1", features = ["accelerate"] }
|
||||
candle-nn = { path = "../candle/candle-nn", version = "0.3.1", features = ["accelerate"] }
|
||||
candle-transformers = { path = "../candle/candle-transformers", version = "0.3.1", features = ["accelerate"] }
|
||||
candle-core = { path = "../candle/candle-core" }
|
||||
candle-nn = { path = "../candle/candle-nn" }
|
||||
candle-transformers = { path = "../candle/candle-transformers" }
|
||||
hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" }
|
||||
rand = "0.8.5"
|
||||
tokenizers = "0.14.1"
|
||||
parking_lot = "0.12.1"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
|
||||
2
run.sh
2
run.sh
@@ -1,3 +1,3 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
/Users/silas/Projects/lsp-ai/target/release/lsp-ai
|
||||
/home/silas/Projects/lsp-ai/target/release/lsp-ai
|
||||
|
||||
216
src/main.rs
216
src/main.rs
@@ -4,26 +4,32 @@ use core::panic;
|
||||
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId, Response};
|
||||
use lsp_types::{
|
||||
request::Completion, CompletionItem, CompletionItemKind, CompletionList, CompletionOptions,
|
||||
CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams, Position, Range,
|
||||
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit,
|
||||
CompletionParams, CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
|
||||
Position, Range, RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
use serde::Deserialize;
|
||||
// use pyo3::prelude::*;
|
||||
// use pyo3::types::PyTuple;
|
||||
use ropey::Rope;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
mod transformer;
|
||||
|
||||
static FILE_MAP: once_cell::sync::Lazy<Mutex<HashMap<String, Rope>>> =
|
||||
once_cell::sync::Lazy::new(|| Mutex::new(HashMap::new()));
|
||||
mod models;
|
||||
use models::{Model, ModelParams};
|
||||
|
||||
// Taken directly from: https://github.com/rust-lang/rust-analyzer
|
||||
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
|
||||
notification.method == N::METHOD
|
||||
}
|
||||
|
||||
fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
|
||||
where
|
||||
R: lsp_types::request::Request,
|
||||
R::Params: serde::de::DeserializeOwned,
|
||||
{
|
||||
req.extract(R::METHOD)
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let (connection, io_threads) = Connection::stdio();
|
||||
let server_capabilities = serde_json::to_value(&ServerCapabilities {
|
||||
@@ -41,15 +47,107 @@ fn main() -> Result<()> {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Params {
|
||||
model: Option<String>,
|
||||
model_file: Option<String>,
|
||||
model_type: Option<String>,
|
||||
device: Option<String>,
|
||||
// We may want to put other non-model related parameters here in the future
|
||||
model_params: Option<ModelParams>,
|
||||
}
|
||||
|
||||
struct CompletionRequest {
|
||||
id: RequestId,
|
||||
params: CompletionParams,
|
||||
rope: Rope,
|
||||
}
|
||||
|
||||
impl CompletionRequest {
|
||||
fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self {
|
||||
Self { id, params, rope }
|
||||
}
|
||||
}
|
||||
|
||||
// This main loop is tricky
|
||||
// We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get
|
||||
// Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request
|
||||
fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||
let params: Params = serde_json::from_value(params)?;
|
||||
let mut text_generation = transformer::build()?;
|
||||
|
||||
// Prep variables
|
||||
let connection = Arc::new(connection);
|
||||
let mut file_map: HashMap<String, Rope> = HashMap::new();
|
||||
|
||||
// How we communicate between the worker and receiver threads
|
||||
let last_completion_request = Arc::new(Mutex::new(None));
|
||||
|
||||
// Thread local variables
|
||||
let thread_last_completion_request = last_completion_request.clone();
|
||||
let thread_connection = connection.clone();
|
||||
// We need to allow unreachabel to be able to use the question mark operators here
|
||||
// We could probably restructure this to not require it
|
||||
#[allow(unreachable_code)]
|
||||
thread::spawn(move || {
|
||||
// Build the model from the params
|
||||
let mut model: Box<dyn Model> = params.model_params.unwrap_or_default().try_into()?;
|
||||
loop {
|
||||
// I think we need this drop, not 100% sure though
|
||||
let mut completion_request = thread_last_completion_request.lock();
|
||||
let params = std::mem::take(&mut *completion_request);
|
||||
drop(completion_request);
|
||||
if let Some(CompletionRequest {
|
||||
id,
|
||||
params,
|
||||
mut rope,
|
||||
}) = params
|
||||
{
|
||||
let filter_text = rope
|
||||
.get_line(params.text_document_position.position.line as usize)
|
||||
.context("Error getting line with ropey")?
|
||||
.to_string();
|
||||
|
||||
// Convert rope to correct prompt for llm
|
||||
let start_index = rope
|
||||
.line_to_char(params.text_document_position.position.line as usize)
|
||||
+ params.text_document_position.position.character as usize;
|
||||
rope.insert(start_index, "<fim_suffix>");
|
||||
let prompt = format!("<fim_prefix>{}<fim_middle>", rope);
|
||||
let insert_text = model.run(&prompt)?;
|
||||
|
||||
// Create and return the completion
|
||||
let completion_text_edit = TextEdit::new(
|
||||
Range::new(
|
||||
Position::new(
|
||||
params.text_document_position.position.line,
|
||||
params.text_document_position.position.character,
|
||||
),
|
||||
Position::new(
|
||||
params.text_document_position.position.line,
|
||||
params.text_document_position.position.character,
|
||||
),
|
||||
),
|
||||
insert_text.clone(),
|
||||
);
|
||||
let item = CompletionItem {
|
||||
label: format!("ai - {insert_text}"),
|
||||
filter_text: Some(filter_text),
|
||||
text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
|
||||
kind: Some(CompletionItemKind::TEXT),
|
||||
..Default::default()
|
||||
};
|
||||
let completion_list = CompletionList {
|
||||
is_incomplete: false,
|
||||
items: vec![item],
|
||||
};
|
||||
let result = Some(CompletionResponse::List(completion_list));
|
||||
let result = serde_json::to_value(&result).unwrap();
|
||||
let resp = Response {
|
||||
id,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
};
|
||||
thread_connection.sender.send(Message::Response(resp))?;
|
||||
}
|
||||
thread::sleep(std::time::Duration::from_millis(5));
|
||||
}
|
||||
anyhow::Ok(())
|
||||
});
|
||||
|
||||
for msg in &connection.receiver {
|
||||
match msg {
|
||||
Message::Request(req) => {
|
||||
@@ -59,59 +157,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||
match cast::<Completion>(req) {
|
||||
Ok((id, params)) => {
|
||||
// Get rope
|
||||
let file_map = FILE_MAP.lock();
|
||||
let mut rope = file_map
|
||||
let rope = file_map
|
||||
.get(params.text_document_position.text_document.uri.as_str())
|
||||
.context("Error file not found")?
|
||||
.clone();
|
||||
let filter_text = rope
|
||||
.get_line(params.text_document_position.position.line as usize)
|
||||
.context("Error getting line with ropey")?
|
||||
.to_string();
|
||||
|
||||
// Convert rope to correct prompt for llm
|
||||
let start_index = rope
|
||||
.line_to_char(params.text_document_position.position.line as usize)
|
||||
+ params.text_document_position.position.character as usize;
|
||||
rope.insert(start_index, "<fim_suffix>");
|
||||
let prompt = format!("<fim_prefix>{}<fim_middle>", rope);
|
||||
let insert_text = text_generation.run(&prompt, 64)?;
|
||||
|
||||
// Create and return the completion
|
||||
let completion_text_edit = TextEdit::new(
|
||||
Range::new(
|
||||
Position::new(
|
||||
params.text_document_position.position.line,
|
||||
params.text_document_position.position.character,
|
||||
),
|
||||
Position::new(
|
||||
params.text_document_position.position.line,
|
||||
params.text_document_position.position.character,
|
||||
),
|
||||
),
|
||||
insert_text.clone(),
|
||||
);
|
||||
let item = CompletionItem {
|
||||
label: format!("ai - {insert_text}"),
|
||||
filter_text: Some(filter_text),
|
||||
text_edit: Some(lsp_types::CompletionTextEdit::Edit(
|
||||
completion_text_edit,
|
||||
)),
|
||||
kind: Some(CompletionItemKind::TEXT),
|
||||
..Default::default()
|
||||
};
|
||||
let completion_list = CompletionList {
|
||||
is_incomplete: false,
|
||||
items: vec![item],
|
||||
};
|
||||
let result = Some(CompletionResponse::List(completion_list));
|
||||
let result = serde_json::to_value(&result).unwrap();
|
||||
let resp = Response {
|
||||
id,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
};
|
||||
connection.sender.send(Message::Response(resp))?;
|
||||
// Update the last CompletionRequest
|
||||
let mut lcr = last_completion_request.lock();
|
||||
*lcr = Some(CompletionRequest::new(id, params, rope));
|
||||
continue;
|
||||
}
|
||||
Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
|
||||
@@ -123,11 +175,9 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||
if notification_is::<lsp_types::notification::DidOpenTextDocument>(¬) {
|
||||
let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
|
||||
let rope = Rope::from_str(¶ms.text_document.text);
|
||||
let mut file_map = FILE_MAP.lock();
|
||||
file_map.insert(params.text_document.uri.to_string(), rope);
|
||||
} else if notification_is::<lsp_types::notification::DidChangeTextDocument>(¬) {
|
||||
let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
|
||||
let mut file_map = FILE_MAP.lock();
|
||||
let rope = file_map
|
||||
.get_mut(params.text_document.uri.as_str())
|
||||
.context("Error trying to get file that does not exist")?;
|
||||
@@ -146,7 +196,6 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||
}
|
||||
} else if notification_is::<lsp_types::notification::DidRenameFiles>(¬) {
|
||||
let params: RenameFilesParams = serde_json::from_value(not.params)?;
|
||||
let mut file_map = FILE_MAP.lock();
|
||||
for file_rename in params.files {
|
||||
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
|
||||
file_map.insert(file_rename.new_uri, rope);
|
||||
@@ -159,36 +208,3 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
|
||||
where
|
||||
R: lsp_types::request::Request,
|
||||
R::Params: serde::de::DeserializeOwned,
|
||||
{
|
||||
req.extract(R::METHOD)
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
|
||||
// #[test]
|
||||
// fn test_lsp() -> Result<()> {
|
||||
// let prompt = "def sum_two_numers(x: int, y:";
|
||||
// let result = Python::with_gil(|py| -> Result<String> {
|
||||
// let transform: Py<PyAny> = PY_MODULE
|
||||
// .as_ref()
|
||||
// .expect("Error getting python module")
|
||||
// .getattr(py, "transform")
|
||||
// .expect("Error getting transform");
|
||||
|
||||
// let output = transform
|
||||
// .call1(py, PyTuple::new(py, &[prompt]))
|
||||
// .expect("Error calling transform");
|
||||
|
||||
// Ok(output.extract(py).expect("Error extracting result"))
|
||||
// })?;
|
||||
// println!("\n\nTHE RESULT\n{:?}\n\n", result);
|
||||
// Ok(())
|
||||
// }
|
||||
// }
|
||||
|
||||
32
src/models/mod.rs
Normal file
32
src/models/mod.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use anyhow::Result;
|
||||
use serde::Deserialize;
|
||||
|
||||
mod starcoder;
|
||||
|
||||
pub trait Model {
|
||||
fn run(&mut self, prompt: &str) -> Result<String>;
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Default)]
|
||||
pub struct ModelParams {
|
||||
model: Option<String>,
|
||||
model_file: Option<String>,
|
||||
model_type: Option<String>,
|
||||
max_length: Option<usize>,
|
||||
}
|
||||
|
||||
impl TryFrom<ModelParams> for Box<dyn Model> {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: ModelParams) -> Result<Self> {
|
||||
let model_type = value.model_type.unwrap_or("starcoder".to_string());
|
||||
let max_length = value.max_length.unwrap_or(12);
|
||||
Ok(Box::new(match model_type.as_str() {
|
||||
"starcoder" => starcoder::build_model(value.model, value.model_file, max_length)?,
|
||||
_ => anyhow::bail!(
|
||||
"Model type: {} not supported. Feel free to make a pr or create a github issue.",
|
||||
model_type
|
||||
),
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -6,32 +6,16 @@ use candle_transformers::models::bigcode::{Config, GPTBigCode};
|
||||
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
pub struct TextGeneration {
|
||||
pub struct Model {
|
||||
model: GPTBigCode,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
max_length: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
fn new(
|
||||
model: GPTBigCode,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<String> {
|
||||
impl super::Model for Model {
|
||||
fn run(&mut self, prompt: &str) -> Result<String> {
|
||||
eprintln!("Starting to generate tokens");
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
@@ -42,7 +26,7 @@ impl TextGeneration {
|
||||
let mut new_tokens = vec![];
|
||||
let mut outputs = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
for index in 0..self.max_length {
|
||||
let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
|
||||
(1, tokens.len().saturating_sub(1))
|
||||
} else {
|
||||
@@ -62,15 +46,40 @@ impl TextGeneration {
|
||||
let dt = start_gen.elapsed();
|
||||
self.model.clear_cache();
|
||||
eprintln!(
|
||||
"GENERATED {} tokens in {} seconds",
|
||||
"GENERATED {} tokens in {} milliseconds",
|
||||
outputs.len(),
|
||||
dt.as_secs()
|
||||
dt.as_millis()
|
||||
);
|
||||
Ok(outputs.join(""))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build() -> Result<TextGeneration> {
|
||||
impl Model {
|
||||
fn new(
|
||||
model: GPTBigCode,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
device: &Device,
|
||||
max_length: usize,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
device: device.clone(),
|
||||
max_length,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_model(
|
||||
model: Option<String>,
|
||||
model_file: Option<String>,
|
||||
max_length: usize,
|
||||
) -> Result<Model> {
|
||||
let start = std::time::Instant::now();
|
||||
eprintln!("Loading in model");
|
||||
let api = ApiBuilder::new()
|
||||
@@ -87,17 +96,24 @@ pub fn build() -> Result<TextGeneration> {
|
||||
.map(|f| repo.get(f))
|
||||
.collect::<std::result::Result<Vec<_>, _>>()?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
// Set the device
|
||||
#[cfg(feature = "cuda")]
|
||||
let device = Device::new_cuda(0)?;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let device = Device::Cpu;
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let config = Config::starcoder_1b();
|
||||
let model = GPTBigCode::load(vb, config)?;
|
||||
eprintln!("loaded the model in {:?}", start.elapsed());
|
||||
Ok(TextGeneration::new(
|
||||
Ok(Model::new(
|
||||
model,
|
||||
tokenizer,
|
||||
0,
|
||||
Some(0.85),
|
||||
None,
|
||||
&device,
|
||||
max_length,
|
||||
))
|
||||
}
|
||||
Reference in New Issue
Block a user