mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2025-12-18 23:14:28 +01:00
Working
This commit is contained in:
@@ -18,7 +18,7 @@ tokenizers = "0.14.1"
|
||||
parking_lot = "0.12.1"
|
||||
once_cell = "1.19.0"
|
||||
directories = "5.0.1"
|
||||
llama-cpp-2 = "0.1.47"
|
||||
llama-cpp-2 = { version = "0.1.47", optional = true }
|
||||
minijinja = { version = "1.0.12", features = ["loader"] }
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||
tracing = "0.1.40"
|
||||
@@ -32,6 +32,7 @@ async-trait = "0.1.78"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
llamacpp = ["dep:llama-cpp-2"]
|
||||
cublas = ["llama-cpp-2/cublas"]
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
147
src/config.rs
147
src/config.rs
@@ -3,12 +3,6 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
|
||||
const DEFAULT_OPENAI_MAX_CONTEXT_LENGTH: usize = 2048;
|
||||
|
||||
const DEFAULT_MAX_COMPLETION_TOKENS: usize = 16;
|
||||
const DEFAULT_MAX_GENERATION_TOKENS: usize = 64;
|
||||
|
||||
pub type Kwargs = HashMap<String, Value>;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -19,15 +13,10 @@ pub enum ValidMemoryBackend {
|
||||
PostgresML(PostgresML),
|
||||
}
|
||||
|
||||
impl Default for ValidMemoryBackend {
|
||||
fn default() -> Self {
|
||||
ValidMemoryBackend::FileStore(FileStore::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ValidModel {
|
||||
#[cfg(feature = "llamacpp")]
|
||||
#[serde(rename = "llamacpp")]
|
||||
LLaMACPP(LLaMACPP),
|
||||
#[serde(rename = "openai")]
|
||||
@@ -36,12 +25,6 @@ pub enum ValidModel {
|
||||
Anthropic(Anthropic),
|
||||
}
|
||||
|
||||
impl Default for ValidModel {
|
||||
fn default() -> Self {
|
||||
ValidModel::LLaMACPP(LLaMACPP::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
@@ -64,21 +47,6 @@ pub struct FIM {
|
||||
pub end: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct MaxTokens {
|
||||
pub completion: usize,
|
||||
pub generation: usize,
|
||||
}
|
||||
|
||||
impl Default for MaxTokens {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
completion: DEFAULT_MAX_COMPLETION_TOKENS,
|
||||
generation: DEFAULT_MAX_GENERATION_TOKENS,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct PostgresML {
|
||||
pub database_url: Option<String>,
|
||||
@@ -98,35 +66,23 @@ pub struct Model {
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
const fn n_gpu_layers_default() -> u32 {
|
||||
1000
|
||||
}
|
||||
|
||||
const fn n_ctx_default() -> u32 {
|
||||
1000
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct LLaMACPP {
|
||||
// The model to use
|
||||
#[serde(flatten)]
|
||||
pub model: Model,
|
||||
|
||||
// TODO: Remove Kwargs here and replace with concrete types
|
||||
// Kwargs passed to LlamaCPP
|
||||
#[serde(flatten)]
|
||||
pub kwargs: Kwargs,
|
||||
}
|
||||
|
||||
impl Default for LLaMACPP {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model: Model {
|
||||
repository: "stabilityai/stable-code-3b".to_string(),
|
||||
name: Some("stable-code-3b-Q5_K_M.gguf".to_string()),
|
||||
},
|
||||
// fim: Some(FIM {
|
||||
// start: "<fim_prefix>".to_string(),
|
||||
// middle: "<fim_suffix>".to_string(),
|
||||
// end: "<fim_middle>".to_string(),
|
||||
// }),
|
||||
// max_tokens: MaxTokens::default(),
|
||||
// chat: None,
|
||||
kwargs: Kwargs::default(),
|
||||
}
|
||||
}
|
||||
#[serde(default = "n_gpu_layers_default")]
|
||||
pub n_gpu_layers: u32,
|
||||
#[serde(default = "n_ctx_default")]
|
||||
pub n_ctx: u32,
|
||||
}
|
||||
|
||||
const fn api_max_requests_per_second_default() -> f32 {
|
||||
@@ -170,46 +126,17 @@ pub struct Completion {
|
||||
// The model key to use
|
||||
pub model: String,
|
||||
|
||||
// // Model args
|
||||
// pub max_new_tokens: Option<usize>,
|
||||
// pub presence_penalty: Option<f32>,
|
||||
// pub frequency_penalty: Option<f32>,
|
||||
// pub top_p: Option<f32>,
|
||||
// pub temperature: Option<f32>,
|
||||
// pub max_context_length: Option<usize>,
|
||||
|
||||
// // FIM args
|
||||
// pub fim: Option<FIM>,
|
||||
|
||||
// // Chat args
|
||||
// pub chat: Option<Vec<ChatMessage>>,
|
||||
// pub chat_template: Option<String>,
|
||||
// pub chat_format: Option<String>,
|
||||
pub kwargs: HashMap<String, Value>,
|
||||
// Args are deserialized by the backend using them
|
||||
#[serde(flatten)]
|
||||
#[serde(default)]
|
||||
pub kwargs: Kwargs,
|
||||
}
|
||||
|
||||
impl Default for Completion {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model: "default_model".to_string(),
|
||||
// fim: Some(FIM {
|
||||
// start: "<fim_prefix>".to_string(),
|
||||
// middle: "<fim_suffix>".to_string(),
|
||||
// end: "<fim_middle>".to_string(),
|
||||
// }),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Default)]
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct ValidConfig {
|
||||
#[serde(default)]
|
||||
pub memory: ValidMemoryBackend,
|
||||
#[serde(default)]
|
||||
pub models: HashMap<String, ValidModel>,
|
||||
#[serde(default)]
|
||||
pub completion: Completion,
|
||||
pub completion: Option<Completion>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Default)]
|
||||
@@ -234,7 +161,7 @@ impl Config {
|
||||
.remove("initializationOptions");
|
||||
let valid_args = match configuration_args {
|
||||
Some(configuration_args) => serde_json::from_value(configuration_args)?,
|
||||
None => ValidConfig::default(),
|
||||
None => anyhow::bail!("lsp-ai does not currently provide a default configuration. Please pass a configuration. See https://github.com/SilasMarvin/lsp-ai for configuration options and examples"),
|
||||
};
|
||||
let client_params: ValidClientParams = serde_json::from_value(args)?;
|
||||
Ok(Self {
|
||||
@@ -247,17 +174,32 @@ impl Config {
|
||||
// Helpers for the backends ///////////
|
||||
///////////////////////////////////////
|
||||
|
||||
pub fn get_completion_transformer_max_requests_per_second(&self) -> f32 {
|
||||
// We can unwrap here as we verified this exists in the new function
|
||||
pub fn is_completions_enabled(&self) -> bool {
|
||||
self.config.completion.is_some()
|
||||
}
|
||||
|
||||
pub fn get_completion_transformer_max_requests_per_second(&self) -> anyhow::Result<f32> {
|
||||
match &self
|
||||
.config
|
||||
.models
|
||||
.get(&self.config.completion.model)
|
||||
.unwrap()
|
||||
{
|
||||
ValidModel::LLaMACPP(_) => 1.,
|
||||
ValidModel::OpenAI(openai) => openai.max_requests_per_second,
|
||||
ValidModel::Anthropic(anthropic) => anthropic.max_requests_per_second,
|
||||
.get(
|
||||
&self
|
||||
.config
|
||||
.completion
|
||||
.as_ref()
|
||||
.context("Completions is not enabled")?
|
||||
.model,
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"`{}` model not found in `models` config",
|
||||
&self.config.completion.as_ref().unwrap().model
|
||||
)
|
||||
})? {
|
||||
#[cfg(feature = "llamacpp")]
|
||||
ValidModel::LLaMACPP(_) => Ok(1.),
|
||||
ValidModel::OpenAI(openai) => Ok(openai.max_requests_per_second),
|
||||
ValidModel::Anthropic(anthropic) => Ok(anthropic.max_requests_per_second),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -290,6 +232,7 @@ mod test {
|
||||
"middle": "<fim_suffix>",
|
||||
"end": "<fim_middle>"
|
||||
},
|
||||
"max_context": 1024,
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
}
|
||||
@@ -314,7 +257,7 @@ mod test {
|
||||
},
|
||||
"completion": {
|
||||
"model": "model1",
|
||||
"chat": [
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{CONTEXT}",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use lsp_types::TextDocumentPositionParams;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
pub enum Generation {}
|
||||
|
||||
@@ -10,6 +11,8 @@ pub struct GenerationParams {
|
||||
#[serde(flatten)]
|
||||
pub text_document_position: TextDocumentPositionParams,
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
|
||||
|
||||
@@ -17,6 +17,7 @@ mod config;
|
||||
mod custom_requests;
|
||||
mod memory_backends;
|
||||
mod memory_worker;
|
||||
#[cfg(feature = "llamacpp")]
|
||||
mod template;
|
||||
mod transformer_backends;
|
||||
mod transformer_worker;
|
||||
|
||||
@@ -16,7 +16,7 @@ use super::{MemoryBackend, MemoryRunParams, Prompt};
|
||||
|
||||
pub struct FileStore {
|
||||
_crawl: bool,
|
||||
config: Config,
|
||||
_config: Config,
|
||||
file_map: Mutex<HashMap<String, Rope>>,
|
||||
accessed_files: Mutex<IndexSet<String>>,
|
||||
}
|
||||
@@ -25,7 +25,7 @@ impl FileStore {
|
||||
pub fn new(file_store_config: config::FileStore, config: Config) -> Self {
|
||||
Self {
|
||||
_crawl: file_store_config.crawl,
|
||||
config,
|
||||
_config: config,
|
||||
file_map: Mutex::new(HashMap::new()),
|
||||
accessed_files: Mutex::new(IndexSet::new()),
|
||||
}
|
||||
@@ -34,7 +34,7 @@ impl FileStore {
|
||||
pub fn new_without_crawl(config: Config) -> Self {
|
||||
Self {
|
||||
_crawl: false,
|
||||
config,
|
||||
_config: config,
|
||||
file_map: Mutex::new(HashMap::new()),
|
||||
accessed_files: Mutex::new(IndexSet::new()),
|
||||
}
|
||||
@@ -110,7 +110,7 @@ impl FileStore {
|
||||
let (mut rope, cursor_index) =
|
||||
self.get_rope_for_position(position, params.max_context_length)?;
|
||||
|
||||
Ok(match (params.chat.is_some(), params.fim) {
|
||||
Ok(match (params.messages.is_some(), params.fim) {
|
||||
r @ (true, _) | r @ (false, Some(_)) if rope.len_chars() != cursor_index => {
|
||||
let max_length = tokens_to_estimated_characters(params.max_context_length);
|
||||
let start = cursor_index.saturating_sub(max_length / 2);
|
||||
|
||||
@@ -15,9 +15,9 @@ const fn max_context_length_default() -> usize {
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
struct MemoryRunParams {
|
||||
pub struct MemoryRunParams {
|
||||
pub fim: Option<FIM>,
|
||||
pub chat: Option<Vec<ChatMessage>>,
|
||||
pub messages: Option<Vec<ChatMessage>>,
|
||||
#[serde(default = "max_context_length_default")]
|
||||
pub max_context_length: usize,
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::{
|
||||
use super::{file_store::FileStore, MemoryBackend, MemoryRunParams, Prompt};
|
||||
|
||||
pub struct PostgresML {
|
||||
configuration: Config,
|
||||
_config: Config,
|
||||
file_store: FileStore,
|
||||
collection: Collection,
|
||||
pipeline: Pipeline,
|
||||
@@ -105,7 +105,7 @@ impl PostgresML {
|
||||
}
|
||||
});
|
||||
Ok(Self {
|
||||
configuration,
|
||||
_config: configuration,
|
||||
file_store,
|
||||
collection,
|
||||
pipeline,
|
||||
|
||||
@@ -129,7 +129,6 @@ impl TransformerBackend for Anthropic {
|
||||
prompt: &Prompt,
|
||||
params: Value,
|
||||
) -> anyhow::Result<DoCompletionResponse> {
|
||||
// let params: AnthropicRunParams = params.try_into()?;
|
||||
let params: AnthropicRunParams = serde_json::from_value(params)?;
|
||||
let insert_text = self.do_get_chat(prompt, params).await?;
|
||||
Ok(DoCompletionResponse { insert_text })
|
||||
@@ -150,75 +149,68 @@ impl TransformerBackend for Anthropic {
|
||||
async fn do_generate_stream(
|
||||
&self,
|
||||
request: &GenerationStreamRequest,
|
||||
params: Value,
|
||||
_params: Value,
|
||||
) -> anyhow::Result<DoGenerationStreamResponse> {
|
||||
unimplemented!()
|
||||
anyhow::bail!("GenerationStream is not yet implemented")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use serde_json::{from_value, json};
|
||||
|
||||
#[tokio::test]
|
||||
async fn anthropic_chat_do_completion() -> anyhow::Result<()> {
|
||||
// let configuration: config::Anthropic = serde_json::from_value(json!({
|
||||
// "chat_endpoint": "https://api.anthropic.com/v1/messages",
|
||||
// "model": "claude-3-haiku-20240307",
|
||||
// "auth_token_env_var_name": "ANTHROPIC_API_KEY",
|
||||
// "chat": {
|
||||
// "completion": [
|
||||
// {
|
||||
// "role": "system",
|
||||
// "content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
|
||||
// }
|
||||
// ],
|
||||
// },
|
||||
// "max_tokens": {
|
||||
// "completion": 16,
|
||||
// "generation": 64
|
||||
// },
|
||||
// "max_context": 4096
|
||||
// }))?;
|
||||
// let anthropic = Anthropic::new(configuration);
|
||||
// let prompt = Prompt::default_with_cursor();
|
||||
// let response = anthropic.do_completion(&prompt).await?;
|
||||
// assert!(!response.insert_text.is_empty());
|
||||
let configuration: config::Anthropic = from_value(json!({
|
||||
"chat_endpoint": "https://api.anthropic.com/v1/messages",
|
||||
"model": "claude-3-haiku-20240307",
|
||||
"auth_token_env_var_name": "ANTHROPIC_API_KEY",
|
||||
}))?;
|
||||
let anthropic = Anthropic::new(configuration);
|
||||
let prompt = Prompt::default_with_cursor();
|
||||
let run_params = json!({
|
||||
"chat": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
|
||||
}
|
||||
],
|
||||
"max_tokens": 64
|
||||
});
|
||||
let response = anthropic.do_completion(&prompt, run_params).await?;
|
||||
assert!(!response.insert_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn anthropic_chat_do_generate() -> anyhow::Result<()> {
|
||||
// let configuration: config::Anthropic = serde_json::from_value(json!({
|
||||
// "chat_endpoint": "https://api.anthropic.com/v1/messages",
|
||||
// "model": "claude-3-haiku-20240307",
|
||||
// "auth_token_env_var_name": "ANTHROPIC_API_KEY",
|
||||
// "chat": {
|
||||
// "generation": [
|
||||
// {
|
||||
// "role": "system",
|
||||
// "content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
|
||||
// }
|
||||
// ]
|
||||
// },
|
||||
// "max_tokens": {
|
||||
// "completion": 16,
|
||||
// "generation": 64
|
||||
// },
|
||||
// "max_context": 4096
|
||||
// }))?;
|
||||
// let anthropic = Anthropic::new(configuration);
|
||||
// let prompt = Prompt::default_with_cursor();
|
||||
// let response = anthropic.do_generate(&prompt).await?;
|
||||
// assert!(!response.generated_text.is_empty());
|
||||
let configuration: config::Anthropic = from_value(json!({
|
||||
"chat_endpoint": "https://api.anthropic.com/v1/messages",
|
||||
"model": "claude-3-haiku-20240307",
|
||||
"auth_token_env_var_name": "ANTHROPIC_API_KEY",
|
||||
}))?;
|
||||
let anthropic = Anthropic::new(configuration);
|
||||
let prompt = Prompt::default_with_cursor();
|
||||
let run_params = json!({
|
||||
"chat": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
|
||||
}
|
||||
],
|
||||
"max_tokens": 64
|
||||
});
|
||||
let response = anthropic.do_generate(&prompt, run_params).await?;
|
||||
assert!(!response.generated_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,9 +27,9 @@ const fn max_new_tokens_default() -> usize {
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LLaMACPPRunParams {
|
||||
pub fim: Option<FIM>,
|
||||
chat: Option<Vec<ChatMessage>>,
|
||||
chat_template: Option<String>,
|
||||
chat_format: Option<String>,
|
||||
messages: Option<Vec<ChatMessage>>,
|
||||
chat_template: Option<String>, // A Jinja template
|
||||
chat_format: Option<String>, // The name of a template in llamacpp
|
||||
#[serde(default = "max_new_tokens_default")]
|
||||
pub max_new_tokens: usize,
|
||||
// TODO: Explore other arguments
|
||||
@@ -37,7 +37,6 @@ pub struct LLaMACPPRunParams {
|
||||
|
||||
pub struct LLaMACPP {
|
||||
model: Model,
|
||||
configuration: config::LLaMACPP,
|
||||
}
|
||||
|
||||
impl LLaMACPP {
|
||||
@@ -48,14 +47,11 @@ impl LLaMACPP {
|
||||
.model
|
||||
.name
|
||||
.as_ref()
|
||||
.context("Please set `transformer->llamacpp->name` to use LLaMA.cpp")?;
|
||||
.context("Please set `name` to use LLaMA.cpp")?;
|
||||
let repo = api.model(configuration.model.repository.to_owned());
|
||||
let model_path = repo.get(name)?;
|
||||
let model = Model::new(model_path, &configuration.kwargs)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
configuration,
|
||||
})
|
||||
let model = Model::new(model_path, &configuration)?;
|
||||
Ok(Self { model })
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
@@ -64,8 +60,7 @@ impl LLaMACPP {
|
||||
prompt: &Prompt,
|
||||
params: &LLaMACPPRunParams,
|
||||
) -> anyhow::Result<String> {
|
||||
// We need to check that they not only set the `chat` key, but they set the `completion` sub key
|
||||
Ok(match ¶ms.chat {
|
||||
Ok(match ¶ms.messages {
|
||||
Some(completion_messages) => {
|
||||
let chat_messages = format_chat_messages(completion_messages, prompt);
|
||||
if let Some(chat_template) = ¶ms.chat_template {
|
||||
@@ -73,7 +68,8 @@ impl LLaMACPP {
|
||||
let eos_token = self.model.get_eos_token()?;
|
||||
apply_chat_template(chat_template, chat_messages, &bos_token, &eos_token)?
|
||||
} else {
|
||||
self.model.apply_chat_template(chat_messages, None)?
|
||||
self.model
|
||||
.apply_chat_template(chat_messages, params.chat_format.clone())?
|
||||
}
|
||||
}
|
||||
None => prompt.code.to_owned(),
|
||||
@@ -113,9 +109,9 @@ impl TransformerBackend for LLaMACPP {
|
||||
async fn do_generate_stream(
|
||||
&self,
|
||||
_request: &GenerationStreamRequest,
|
||||
params: Value,
|
||||
_params: Value,
|
||||
) -> anyhow::Result<DoGenerationStreamResponse> {
|
||||
unimplemented!()
|
||||
anyhow::bail!("GenerationStream is not yet implemented")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,50 +120,71 @@ mod test {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// // "completion": [
|
||||
// // {
|
||||
// // "role": "system",
|
||||
// // "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
|
||||
// // },
|
||||
// // {
|
||||
// // "role": "user",
|
||||
// // "content": "Complete the following code: \n\n{code}"
|
||||
// // }
|
||||
// // ],
|
||||
// // "generation": [
|
||||
// // {
|
||||
// // "role": "system",
|
||||
// // "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
|
||||
// // },
|
||||
// // {
|
||||
// // "role": "user",
|
||||
// // "content": "Complete the following code: \n\n{code}"
|
||||
// // }
|
||||
// // ],
|
||||
// "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
|
||||
|
||||
#[tokio::test]
|
||||
async fn llama_cpp_do_completion() -> anyhow::Result<()> {
|
||||
// let configuration: config::LLaMACPP = serde_json::from_value(json!({
|
||||
// "repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
|
||||
// "name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
|
||||
// "max_new_tokens": {
|
||||
// "completion": 32,
|
||||
// "generation": 256,
|
||||
// },
|
||||
// "fim": {
|
||||
// "start": "<fim_prefix>",
|
||||
// "middle": "<fim_suffix>",
|
||||
// "end": "<fim_middle>"
|
||||
// },
|
||||
// "chat": {
|
||||
// // "completion": [
|
||||
// // {
|
||||
// // "role": "system",
|
||||
// // "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
|
||||
// // },
|
||||
// // {
|
||||
// // "role": "user",
|
||||
// // "content": "Complete the following code: \n\n{code}"
|
||||
// // }
|
||||
// // ],
|
||||
// // "generation": [
|
||||
// // {
|
||||
// // "role": "system",
|
||||
// // "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
|
||||
// // },
|
||||
// // {
|
||||
// // "role": "user",
|
||||
// // "content": "Complete the following code: \n\n{code}"
|
||||
// // }
|
||||
// // ],
|
||||
// "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
|
||||
// },
|
||||
// "n_ctx": 2048,
|
||||
// "n_gpu_layers": 35,
|
||||
// }))?;
|
||||
// let llama_cpp = LLaMACPP::new(configuration).unwrap();
|
||||
// let prompt = Prompt::default_with_cursor();
|
||||
// let response = llama_cpp.do_completion(&prompt).await?;
|
||||
// assert!(!response.insert_text.is_empty());
|
||||
let configuration: config::LLaMACPP = serde_json::from_value(json!({
|
||||
"repository": "stabilityai/stable-code-3b",
|
||||
"name": "stable-code-3b-Q5_K_M.gguf",
|
||||
"n_ctx": 2048,
|
||||
"n_gpu_layers": 35,
|
||||
}))?;
|
||||
let llama_cpp = LLaMACPP::new(configuration).unwrap();
|
||||
let prompt = Prompt::default_with_cursor();
|
||||
let run_params = json!({
|
||||
"fim": {
|
||||
"start": "<fim_prefix>",
|
||||
"middle": "<fim_suffix>",
|
||||
"end": "<fim_middle>"
|
||||
},
|
||||
"max_tokens": 64
|
||||
});
|
||||
let response = llama_cpp.do_completion(&prompt, run_params).await?;
|
||||
assert!(!response.insert_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn llama_cpp_do_generate() -> anyhow::Result<()> {
|
||||
let configuration: config::LLaMACPP = serde_json::from_value(json!({
|
||||
"repository": "stabilityai/stable-code-3b",
|
||||
"name": "stable-code-3b-Q5_K_M.gguf",
|
||||
"n_ctx": 2048,
|
||||
"n_gpu_layers": 35,
|
||||
}))?;
|
||||
let llama_cpp = LLaMACPP::new(configuration).unwrap();
|
||||
let prompt = Prompt::default_with_cursor();
|
||||
let run_params = json!({
|
||||
"fim": {
|
||||
"start": "<fim_prefix>",
|
||||
"middle": "<fim_suffix>",
|
||||
"end": "<fim_middle>"
|
||||
},
|
||||
"max_tokens": 64
|
||||
});
|
||||
let response = llama_cpp.do_generate(&prompt, run_params).await?;
|
||||
assert!(!response.generated_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ use once_cell::sync::Lazy;
|
||||
use std::{num::NonZeroU32, path::PathBuf, time::Duration};
|
||||
use tracing::{debug, info, instrument};
|
||||
|
||||
use crate::config::{ChatMessage, Kwargs};
|
||||
use crate::config::{self, ChatMessage};
|
||||
|
||||
use super::LLaMACPPRunParams;
|
||||
|
||||
@@ -24,34 +24,18 @@ pub struct Model {
|
||||
|
||||
impl Model {
|
||||
#[instrument]
|
||||
pub fn new(model_path: PathBuf, kwargs: &Kwargs) -> anyhow::Result<Self> {
|
||||
// Get n_gpu_layers if set in kwargs
|
||||
// As a default we set it to 1000, which should put all layers on the GPU
|
||||
let n_gpu_layers = kwargs
|
||||
.get("n_gpu_layers")
|
||||
.map(|u| anyhow::Ok(u.as_u64().context("n_gpu_layers must be a number")? as u32))
|
||||
.unwrap_or_else(|| Ok(1000))?;
|
||||
|
||||
pub fn new(model_path: PathBuf, config: &config::LLaMACPP) -> anyhow::Result<Self> {
|
||||
// Initialize the model_params
|
||||
let model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);
|
||||
let model_params = LlamaModelParams::default().with_n_gpu_layers(config.n_gpu_layers);
|
||||
|
||||
// Load the model
|
||||
debug!("Loading model at path: {:?}", model_path);
|
||||
let model = LlamaModel::load_from_file(&BACKEND, model_path, &model_params)?;
|
||||
|
||||
// Get n_ctx if set in kwargs
|
||||
// As a default we set it to 2048
|
||||
let n_ctx = kwargs
|
||||
.get("n_ctx")
|
||||
.map(|u| {
|
||||
anyhow::Ok(NonZeroU32::new(
|
||||
u.as_u64().context("n_ctx must be a number")? as u32,
|
||||
))
|
||||
})
|
||||
.unwrap_or_else(|| Ok(NonZeroU32::new(2048)))?
|
||||
.context("n_ctx must not be zero")?;
|
||||
|
||||
Ok(Model { model, n_ctx })
|
||||
Ok(Model {
|
||||
model,
|
||||
n_ctx: NonZeroU32::new(config.n_ctx).context("`n_ctx` must be non zero")?,
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
config::{self, ValidModel},
|
||||
config::ValidModel,
|
||||
memory_backends::Prompt,
|
||||
transformer_worker::{
|
||||
DoCompletionResponse, DoGenerationResponse, DoGenerationStreamResponse,
|
||||
@@ -9,37 +9,11 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
use self::{anthropic::AnthropicRunParams, llama_cpp::LLaMACPPRunParams, openai::OpenAIRunParams};
|
||||
|
||||
mod anthropic;
|
||||
#[cfg(feature = "llamacpp")]
|
||||
mod llama_cpp;
|
||||
mod openai;
|
||||
|
||||
// impl RunParams {
|
||||
// pub fn from_completion(completion: &Completion) -> Self {
|
||||
// todo!()
|
||||
// }
|
||||
// }
|
||||
|
||||
// macro_rules! impl_runparams_try_into {
|
||||
// ( $f:ident, $t:ident ) => {
|
||||
// impl TryInto<$f> for RunParams {
|
||||
// type Error = anyhow::Error;
|
||||
|
||||
// fn try_into(self) -> Result<$f, Self::Error> {
|
||||
// match self {
|
||||
// Self::$t(a) => Ok(a),
|
||||
// _ => anyhow::bail!("Cannot convert RunParams into {}", stringify!($f)),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
|
||||
// impl_runparams_try_into!(AnthropicRunParams, Anthropic);
|
||||
// impl_runparams_try_into!(LLaMACPPRunParams, LLaMACPP);
|
||||
// impl_runparams_try_into!(OpenAIRunParams, OpenAI);
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait TransformerBackend {
|
||||
async fn do_completion(
|
||||
@@ -64,6 +38,7 @@ impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
|
||||
|
||||
fn try_from(valid_model: ValidModel) -> Result<Self, Self::Error> {
|
||||
match valid_model {
|
||||
#[cfg(feature = "llamacpp")]
|
||||
ValidModel::LLaMACPP(model_gguf) => Ok(Box::new(llama_cpp::LLaMACPP::new(model_gguf)?)),
|
||||
ValidModel::OpenAI(openai_config) => Ok(Box::new(openai::OpenAI::new(openai_config))),
|
||||
ValidModel::Anthropic(anthropic_config) => {
|
||||
|
||||
@@ -38,7 +38,7 @@ const fn temperature_default() -> f32 {
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OpenAIRunParams {
|
||||
pub fim: Option<FIM>,
|
||||
chat: Option<Vec<ChatMessage>>,
|
||||
messages: Option<Vec<ChatMessage>>,
|
||||
#[serde(default = "max_tokens_default")]
|
||||
pub max_tokens: usize,
|
||||
#[serde(default = "top_p_default")]
|
||||
@@ -89,7 +89,9 @@ impl OpenAI {
|
||||
} else if let Some(token) = &self.configuration.auth_token {
|
||||
Ok(token.to_string())
|
||||
} else {
|
||||
anyhow::bail!("set `auth_token_env_var_name` or `auth_token` in `tranformer->openai` to use an OpenAI compatible API")
|
||||
anyhow::bail!(
|
||||
"set `auth_token_env_var_name` or `auth_token` to use an OpenAI compatible API"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,7 +107,7 @@ impl OpenAI {
|
||||
self.configuration
|
||||
.completions_endpoint
|
||||
.as_ref()
|
||||
.context("specify `transformer->openai->completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `transformer->openai->chat_endpoint` and `transformer->openai->chat` messages.")?,
|
||||
.context("specify `completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `chat_endpoint` and `messages`.")?,
|
||||
)
|
||||
.bearer_auth(token)
|
||||
.header("Content-Type", "application/json")
|
||||
@@ -168,7 +170,7 @@ impl OpenAI {
|
||||
} else if let Some(choices) = res.choices {
|
||||
Ok(choices[0].message.content.clone())
|
||||
} else {
|
||||
anyhow::bail!("Uknown error while making request to OpenAI")
|
||||
anyhow::bail!("Unknown error while making request to OpenAI")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,7 +179,7 @@ impl OpenAI {
|
||||
prompt: &Prompt,
|
||||
params: OpenAIRunParams,
|
||||
) -> anyhow::Result<String> {
|
||||
match ¶ms.chat {
|
||||
match ¶ms.messages {
|
||||
Some(completion_messages) => {
|
||||
let messages = format_chat_messages(completion_messages, prompt);
|
||||
self.get_chat(messages, params).await
|
||||
@@ -198,7 +200,6 @@ impl TransformerBackend for OpenAI {
|
||||
prompt: &Prompt,
|
||||
params: Value,
|
||||
) -> anyhow::Result<DoCompletionResponse> {
|
||||
// let params: OpenAIRunParams = params.try_into()?;
|
||||
let params: OpenAIRunParams = serde_json::from_value(params)?;
|
||||
let insert_text = self.do_chat_completion(prompt, params).await?;
|
||||
Ok(DoCompletionResponse { insert_text })
|
||||
@@ -220,113 +221,102 @@ impl TransformerBackend for OpenAI {
|
||||
async fn do_generate_stream(
|
||||
&self,
|
||||
request: &GenerationStreamRequest,
|
||||
params: Value,
|
||||
_params: Value,
|
||||
) -> anyhow::Result<DoGenerationStreamResponse> {
|
||||
unimplemented!()
|
||||
anyhow::bail!("GenerationStream is not yet implemented")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use serde_json::{from_value, json};
|
||||
|
||||
#[tokio::test]
|
||||
async fn openai_completion_do_completion() -> anyhow::Result<()> {
|
||||
// let configuration: config::OpenAI = serde_json::from_value(json!({
|
||||
// "completions_endpoint": "https://api.openai.com/v1/completions",
|
||||
// "model": "gpt-3.5-turbo-instruct",
|
||||
// "auth_token_env_var_name": "OPENAI_API_KEY",
|
||||
// "max_tokens": {
|
||||
// "completion": 16,
|
||||
// "generation": 64
|
||||
// },
|
||||
// "max_context": 4096
|
||||
// }))?;
|
||||
// let openai = OpenAI::new(configuration);
|
||||
// let prompt = Prompt::default_without_cursor();
|
||||
// let response = openai.do_completion(&prompt).await?;
|
||||
// assert!(!response.insert_text.is_empty());
|
||||
let configuration: config::OpenAI = from_value(json!({
|
||||
"completions_endpoint": "https://api.openai.com/v1/completions",
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"auth_token_env_var_name": "OPENAI_API_KEY",
|
||||
}))?;
|
||||
let openai = OpenAI::new(configuration);
|
||||
let prompt = Prompt::default_without_cursor();
|
||||
let run_params = json!({
|
||||
"max_tokens": 64
|
||||
});
|
||||
let response = openai.do_completion(&prompt, run_params).await?;
|
||||
assert!(!response.insert_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn openai_chat_do_completion() -> anyhow::Result<()> {
|
||||
// let configuration: config::OpenAI = serde_json::from_value(json!({
|
||||
// "chat_endpoint": "https://api.openai.com/v1/chat/completions",
|
||||
// "model": "gpt-3.5-turbo",
|
||||
// "auth_token_env_var_name": "OPENAI_API_KEY",
|
||||
// "chat": {
|
||||
// "completion": [
|
||||
// {
|
||||
// "role": "system",
|
||||
// "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
|
||||
// }
|
||||
// ],
|
||||
// },
|
||||
// "max_tokens": {
|
||||
// "completion": 16,
|
||||
// "generation": 64
|
||||
// },
|
||||
// "max_context": 4096
|
||||
// }))?;
|
||||
// let openai = OpenAI::new(configuration);
|
||||
// let prompt = Prompt::default_with_cursor();
|
||||
// let response = openai.do_completion(&prompt).await?;
|
||||
// assert!(!response.insert_text.is_empty());
|
||||
let configuration: config::OpenAI = serde_json::from_value(json!({
|
||||
"chat_endpoint": "https://api.openai.com/v1/chat/completions",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"auth_token_env_var_name": "OPENAI_API_KEY",
|
||||
}))?;
|
||||
let openai = OpenAI::new(configuration);
|
||||
let prompt = Prompt::default_with_cursor();
|
||||
let run_params = json!({
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
|
||||
}
|
||||
],
|
||||
"max_tokens": 64
|
||||
});
|
||||
let response = openai.do_completion(&prompt, run_params).await?;
|
||||
assert!(!response.insert_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn openai_completion_do_generate() -> anyhow::Result<()> {
|
||||
// let configuration: config::OpenAI = serde_json::from_value(json!({
|
||||
// "completions_endpoint": "https://api.openai.com/v1/completions",
|
||||
// "model": "gpt-3.5-turbo-instruct",
|
||||
// "auth_token_env_var_name": "OPENAI_API_KEY",
|
||||
// "max_tokens": {
|
||||
// "completion": 16,
|
||||
// "generation": 64
|
||||
// },
|
||||
// "max_context": 4096
|
||||
// }))?;
|
||||
// let openai = OpenAI::new(configuration);
|
||||
// let prompt = Prompt::default_without_cursor();
|
||||
// let response = openai.do_generate(&prompt).await?;
|
||||
// assert!(!response.generated_text.is_empty());
|
||||
let configuration: config::OpenAI = from_value(json!({
|
||||
"completions_endpoint": "https://api.openai.com/v1/completions",
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"auth_token_env_var_name": "OPENAI_API_KEY",
|
||||
}))?;
|
||||
let openai = OpenAI::new(configuration);
|
||||
let prompt = Prompt::default_without_cursor();
|
||||
let run_params = json!({
|
||||
"max_tokens": 64
|
||||
});
|
||||
let response = openai.do_generate(&prompt, run_params).await?;
|
||||
assert!(!response.generated_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn openai_chat_do_generate() -> anyhow::Result<()> {
|
||||
// let configuration: config::OpenAI = serde_json::from_value(json!({
|
||||
// "chat_endpoint": "https://api.openai.com/v1/chat/completions",
|
||||
// "model": "gpt-3.5-turbo",
|
||||
// "auth_token_env_var_name": "OPENAI_API_KEY",
|
||||
// "chat": {
|
||||
// "generation": [
|
||||
// {
|
||||
// "role": "system",
|
||||
// "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
|
||||
// }
|
||||
// ]
|
||||
// },
|
||||
// "max_tokens": {
|
||||
// "completion": 16,
|
||||
// "generation": 64
|
||||
// },
|
||||
// "max_context": 4096
|
||||
// }))?;
|
||||
// let openai = OpenAI::new(configuration);
|
||||
// let prompt = Prompt::default_with_cursor();
|
||||
// let response = openai.do_generate(&prompt).await?;
|
||||
// assert!(!response.generated_text.is_empty());
|
||||
let configuration: config::OpenAI = serde_json::from_value(json!({
|
||||
"chat_endpoint": "https://api.openai.com/v1/chat/completions",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"auth_token_env_var_name": "OPENAI_API_KEY",
|
||||
}))?;
|
||||
let openai = OpenAI::new(configuration);
|
||||
let prompt = Prompt::default_with_cursor();
|
||||
let run_params = json!({
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
|
||||
}
|
||||
],
|
||||
"max_tokens": 64
|
||||
});
|
||||
let response = openai.do_generate(&prompt, run_params).await?;
|
||||
assert!(!response.generated_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,6 +63,16 @@ pub enum WorkerRequest {
|
||||
GenerationStream(GenerationStreamRequest),
|
||||
}
|
||||
|
||||
impl WorkerRequest {
|
||||
fn get_id(&self) -> RequestId {
|
||||
match self {
|
||||
WorkerRequest::Completion(r) => r.id.clone(),
|
||||
WorkerRequest::Generation(r) => r.id.clone(),
|
||||
WorkerRequest::GenerationStream(r) => r.id.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DoCompletionResponse {
|
||||
pub insert_text: String,
|
||||
}
|
||||
@@ -75,122 +85,6 @@ pub struct DoGenerationStreamResponse {
|
||||
pub generated_text: String,
|
||||
}
|
||||
|
||||
#[instrument(skip(config, transformer_backend, memory_backend_tx, connection))]
|
||||
async fn do_task(
|
||||
transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
|
||||
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||
request: WorkerRequest,
|
||||
connection: Arc<Connection>,
|
||||
config: Config,
|
||||
) -> anyhow::Result<()> {
|
||||
let response = match request {
|
||||
WorkerRequest::Completion(request) => {
|
||||
match do_completion(transformer_backend, memory_backend_tx, &request, &config).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => Response {
|
||||
id: request.id,
|
||||
result: None,
|
||||
error: Some(e.to_response_error(-32603)),
|
||||
},
|
||||
}
|
||||
}
|
||||
WorkerRequest::Generation(request) => {
|
||||
match do_generate(transformer_backend, memory_backend_tx, &request).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => Response {
|
||||
id: request.id,
|
||||
result: None,
|
||||
error: Some(e.to_response_error(-32603)),
|
||||
},
|
||||
}
|
||||
}
|
||||
WorkerRequest::GenerationStream(_) => {
|
||||
panic!("Streaming is not yet supported")
|
||||
}
|
||||
};
|
||||
connection
|
||||
.sender
|
||||
.send(Message::Response(response))
|
||||
.expect("Error sending response");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn do_run(
|
||||
transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>>,
|
||||
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||
transformer_rx: std::sync::mpsc::Receiver<WorkerRequest>,
|
||||
connection: Arc<Connection>,
|
||||
config: Config,
|
||||
) -> anyhow::Result<()> {
|
||||
let transformer_backends: HashMap<String, Arc<Box<dyn TransformerBackend + Send + Sync>>> =
|
||||
transformer_backends
|
||||
.into_iter()
|
||||
.map(|(key, backend)| (key, Arc::new(backend)))
|
||||
.collect();
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(4)
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let max_requests_per_second = config.get_completion_transformer_max_requests_per_second();
|
||||
let mut last_completion_request_time = SystemTime::now();
|
||||
let mut last_completion_request = None;
|
||||
|
||||
let dispatch_request = move |request| {
|
||||
let thread_transformer_backend = transformer_backends
|
||||
.get(&config.config.completion.model)
|
||||
.clone()
|
||||
.with_context(|| format!("can't find model: {}", config.config.completion.model))?
|
||||
.clone();
|
||||
let thread_config = config.clone();
|
||||
let thread_memory_backend_tx = memory_backend_tx.clone();
|
||||
let thread_connection = connection.clone();
|
||||
runtime.spawn(async move {
|
||||
if let Err(e) = do_task(
|
||||
thread_transformer_backend,
|
||||
thread_memory_backend_tx,
|
||||
request,
|
||||
thread_connection,
|
||||
thread_config,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("transformer worker task: {e}")
|
||||
}
|
||||
});
|
||||
anyhow::Ok(())
|
||||
};
|
||||
|
||||
loop {
|
||||
// We want to rate limit completions without dropping the last rate limited request
|
||||
let request = transformer_rx.recv_timeout(Duration::from_millis(5));
|
||||
|
||||
match request {
|
||||
Ok(request) => match request {
|
||||
WorkerRequest::Completion(_) => last_completion_request = Some(request),
|
||||
_ => {
|
||||
dispatch_request(request)?;
|
||||
}
|
||||
},
|
||||
Err(RecvTimeoutError::Disconnected) => anyhow::bail!("channel disconnected"),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if SystemTime::now()
|
||||
.duration_since(last_completion_request_time)?
|
||||
.as_secs_f32()
|
||||
< 1. / max_requests_per_second
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(request) = last_completion_request.take() {
|
||||
last_completion_request_time = SystemTime::now();
|
||||
dispatch_request(request)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>>,
|
||||
memory_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||
@@ -209,17 +103,151 @@ pub fn run(
|
||||
}
|
||||
}
|
||||
|
||||
fn do_run(
|
||||
transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>>,
|
||||
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||
transformer_rx: std::sync::mpsc::Receiver<WorkerRequest>,
|
||||
connection: Arc<Connection>,
|
||||
config: Config,
|
||||
) -> anyhow::Result<()> {
|
||||
let transformer_backends = Arc::new(transformer_backends);
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(4)
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
// If they have disabled completions, this function will fail. We set it to MIN_POSITIVE to never process a completions request
|
||||
let max_requests_per_second = config
|
||||
.get_completion_transformer_max_requests_per_second()
|
||||
.unwrap_or(f32::MIN_POSITIVE);
|
||||
let mut last_completion_request_time = SystemTime::now();
|
||||
let mut last_completion_request = None;
|
||||
|
||||
let run_dispatch_request = |request| {
|
||||
let task_connection = connection.clone();
|
||||
let task_transformer_backends = transformer_backends.clone();
|
||||
let task_memory_backend_tx = memory_backend_tx.clone();
|
||||
let task_config = config.clone();
|
||||
runtime.spawn(async move {
|
||||
dispatch_request(
|
||||
request,
|
||||
task_connection,
|
||||
task_transformer_backends,
|
||||
task_memory_backend_tx,
|
||||
task_config,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
};
|
||||
|
||||
loop {
|
||||
// We want to rate limit completions without dropping the last rate limited request
|
||||
let request = transformer_rx.recv_timeout(Duration::from_millis(5));
|
||||
|
||||
match request {
|
||||
Ok(request) => match request {
|
||||
WorkerRequest::Completion(_) => last_completion_request = Some(request),
|
||||
_ => run_dispatch_request(request),
|
||||
},
|
||||
Err(RecvTimeoutError::Disconnected) => anyhow::bail!("channel disconnected"),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if SystemTime::now()
|
||||
.duration_since(last_completion_request_time)?
|
||||
.as_secs_f32()
|
||||
< 1. / max_requests_per_second
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(request) = last_completion_request.take() {
|
||||
last_completion_request_time = SystemTime::now();
|
||||
run_dispatch_request(request);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(connection, transformer_backends, memory_backend_tx, config))]
|
||||
async fn dispatch_request(
|
||||
request: WorkerRequest,
|
||||
connection: Arc<Connection>,
|
||||
transformer_backends: Arc<HashMap<String, Box<dyn TransformerBackend + Send + Sync>>>,
|
||||
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||
config: Config,
|
||||
) {
|
||||
let response = match generate_response(
|
||||
request.clone(),
|
||||
transformer_backends,
|
||||
memory_backend_tx,
|
||||
config,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
error!("generating response: {e}");
|
||||
Response {
|
||||
id: request.get_id(),
|
||||
result: None,
|
||||
error: Some(e.to_response_error(-32603)),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = connection.sender.send(Message::Response(response)) {
|
||||
error!("sending response: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn generate_response(
|
||||
request: WorkerRequest,
|
||||
transformer_backends: Arc<HashMap<String, Box<dyn TransformerBackend + Send + Sync>>>,
|
||||
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||
config: Config,
|
||||
) -> anyhow::Result<Response> {
|
||||
match request {
|
||||
WorkerRequest::Completion(request) => {
|
||||
let completion_config = config
|
||||
.config
|
||||
.completion
|
||||
.as_ref()
|
||||
.context("Completions is none")?;
|
||||
let transformer_backend = transformer_backends
|
||||
.get(&completion_config.model)
|
||||
.clone()
|
||||
.with_context(|| format!("can't find model: {}", &completion_config.model))?;
|
||||
do_completion(transformer_backend, memory_backend_tx, &request, &config).await
|
||||
}
|
||||
WorkerRequest::Generation(request) => {
|
||||
let transformer_backend = transformer_backends
|
||||
.get(&request.params.model)
|
||||
.clone()
|
||||
.with_context(|| format!("can't find model: {}", &request.params.model))?;
|
||||
do_generate(transformer_backend, memory_backend_tx, &request).await
|
||||
}
|
||||
WorkerRequest::GenerationStream(_) => {
|
||||
anyhow::bail!("Streaming is not yet supported")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_completion(
|
||||
transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
|
||||
transformer_backend: &Box<dyn TransformerBackend + Send + Sync>,
|
||||
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||
request: &CompletionRequest,
|
||||
config: &Config,
|
||||
) -> anyhow::Result<Response> {
|
||||
// TODO: Fix this
|
||||
// we need to be subtracting the completion / generation tokens from max_context_length
|
||||
// not sure if we should be doing that for the chat maybe leave a note here for that?
|
||||
|
||||
let params = serde_json::to_value(config.config.completion.kwargs.clone()).unwrap();
|
||||
let params = serde_json::to_value(
|
||||
config
|
||||
.config
|
||||
.completion
|
||||
.as_ref()
|
||||
.context("Completions is None")?
|
||||
.kwargs
|
||||
.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
|
||||
@@ -270,27 +298,28 @@ async fn do_completion(
|
||||
}
|
||||
|
||||
async fn do_generate(
|
||||
transformer_backend: Arc<Box<dyn TransformerBackend + Send + Sync>>,
|
||||
transformer_backend: &Box<dyn TransformerBackend + Send + Sync>,
|
||||
memory_backend_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||
request: &GenerationRequest,
|
||||
) -> anyhow::Result<Response> {
|
||||
todo!()
|
||||
// let (tx, rx) = oneshot::channel();
|
||||
// memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
|
||||
// request.params.text_document_position.clone(),
|
||||
// PromptForType::Completion,
|
||||
// tx,
|
||||
// )))?;
|
||||
// let prompt = rx.await?;
|
||||
let params = serde_json::to_value(request.params.parameters.clone()).unwrap();
|
||||
|
||||
// let response = transformer_backend.do_generate(&prompt).await?;
|
||||
// let result = GenerateResult {
|
||||
// generated_text: response.generated_text,
|
||||
// };
|
||||
// let result = serde_json::to_value(result).unwrap();
|
||||
// Ok(Response {
|
||||
// id: request.id.clone(),
|
||||
// result: Some(result),
|
||||
// error: None,
|
||||
// })
|
||||
let (tx, rx) = oneshot::channel();
|
||||
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
|
||||
request.params.text_document_position.clone(),
|
||||
params.clone(),
|
||||
tx,
|
||||
)))?;
|
||||
let prompt = rx.await?;
|
||||
|
||||
let response = transformer_backend.do_generate(&prompt, params).await?;
|
||||
let result = GenerateResult {
|
||||
generated_text: response.generated_text,
|
||||
};
|
||||
let result = serde_json::to_value(result).unwrap();
|
||||
Ok(Response {
|
||||
id: request.id.clone(),
|
||||
result: Some(result),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user