From 31251e723594e9d99e74ce69f36c8e59c87d481c Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Sun, 12 May 2024 18:40:51 -0700
Subject: [PATCH] Compiling

---
 src/config.rs                             | 24 +--------
 src/memory_backends/file_store.rs         | 38 +++++----------
 src/memory_backends/mod.rs                | 25 ++++++----
 src/memory_backends/postgresml/mod.rs     | 14 +++---
 src/memory_worker.rs                      | 18 +++----
 src/transformer_backends/anthropic.rs     | 19 +++-----
 src/transformer_backends/llama_cpp/mod.rs | 13 ++---
 src/transformer_backends/mod.rs           | 59 ++++++++++-------------
 src/transformer_backends/openai/mod.rs    | 19 +++-----
 src/transformer_worker.rs                 | 11 ++---
 10 files changed, 95 insertions(+), 145 deletions(-)

diff --git a/src/config.rs b/src/config.rs
index ab3fd90..6db8a09 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -3,8 +3,6 @@ use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use std::collections::HashMap;
 
-use crate::memory_backends::PromptForType;
-
 const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
 const DEFAULT_OPENAI_MAX_CONTEXT_LENGTH: usize = 2048;
 
@@ -187,7 +185,7 @@ pub struct Completion {
     // pub chat: Option<Vec<ChatMessage>>,
     // pub chat_template: Option<String>,
     // pub chat_format: Option<String>,
-    kwargs: HashMap<String, Value>,
+    pub kwargs: HashMap<String, Value>,
 }
 
 impl Default for Completion {
@@ -262,26 +260,6 @@ impl Config {
             ValidModel::Anthropic(anthropic) => anthropic.max_requests_per_second,
         }
     }
-
-    // pub fn get_completion_max_context_length(&self) -> anyhow::Result<usize> {
-    //     Ok(self.config.completion.max_context_length)
-    // }
-
-    // pub fn get_fim(&self) -> Option<&FIM> {
-    //     match &self.config.transformer {
-    //         ValidModel::LLaMACPP(llama_cpp) => llama_cpp.fim.as_ref(),
-    //         ValidModel::OpenAI(openai) => openai.fim.as_ref(),
-    //         ValidModel::Anthropic(_) => None,
-    //     }
-    // }
-
-    // pub fn get_chat(&self) -> Option<&Chat> {
-    //     match &self.config.transformer {
-    //         ValidModel::LLaMACPP(llama_cpp) => llama_cpp.chat.as_ref(),
-    //         ValidModel::OpenAI(openai) => openai.chat.as_ref(),
-    //         ValidModel::Anthropic(anthropic) => Some(&anthropic.chat),
-    //     }
-    // }
 }
 
 #[cfg(test)]
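The config change above makes `completion.kwargs` a public `HashMap<String, Value>`, so per-backend options ride along as untyped JSON instead of going through strongly typed config getters. A small standalone sketch of that idea (not code from this repository; the key names are only illustrative):

```rust
// Standalone sketch: backend-specific options carried as untyped JSON.
use std::collections::HashMap;

use serde_json::{json, Value};

fn main() -> Result<(), serde_json::Error> {
    // Roughly the shape of `completion.kwargs` after this patch.
    let mut kwargs: HashMap<String, Value> = HashMap::new();
    kwargs.insert("max_context_length".to_string(), json!(2048));
    kwargs.insert("max_tokens".to_string(), json!(64));

    // A single opaque Value travels through the workers; each backend later
    // decides which keys it understands and ignores the rest.
    let params: Value = serde_json::to_value(kwargs)?;
    println!("{params}");
    Ok(())
}
```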
diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs
index f2e49b7..187d84c 100644
--- a/src/memory_backends/file_store.rs
+++ b/src/memory_backends/file_store.rs
@@ -3,6 +3,7 @@ use indexmap::IndexSet;
 use lsp_types::TextDocumentPositionParams;
 use parking_lot::Mutex;
 use ropey::Rope;
+use serde_json::Value;
 use std::collections::HashMap;
 use tracing::instrument;
 
@@ -11,7 +12,7 @@ use crate::{
     utils::tokens_to_estimated_characters,
 };
 
-use super::{MemoryBackend, Prompt, PromptForType};
+use super::{MemoryBackend, MemoryRunParams, Prompt};
 
 pub struct FileStore {
     _crawl: bool,
@@ -104,33 +105,20 @@ impl FileStore {
     pub fn build_code(
         &self,
         position: &TextDocumentPositionParams,
-        prompt_for_type: PromptForType,
-        max_context_length: usize,
+        params: MemoryRunParams,
     ) -> anyhow::Result<String> {
-        let (mut rope, cursor_index) = self.get_rope_for_position(position, max_context_length)?;
+        let (mut rope, cursor_index) =
+            self.get_rope_for_position(position, params.max_context_length)?;
 
-        let is_chat_enabled = match prompt_for_type {
-            PromptForType::Completion => self
-                .config
-                .get_chat()
-                .map(|c| c.completion.is_some())
-                .unwrap_or(false),
-            PromptForType::Generate => self
-                .config
-                .get_chat()
-                .map(|c| c.generation.is_some())
-                .unwrap_or(false),
-        };
-
-        Ok(match (is_chat_enabled, self.config.get_fim()) {
+        Ok(match (params.chat.is_some(), params.fim) {
             r @ (true, _) | r @ (false, Some(_)) if rope.len_chars() != cursor_index => {
-                let max_length = tokens_to_estimated_characters(max_context_length);
+                let max_length = tokens_to_estimated_characters(params.max_context_length);
                 let start = cursor_index.saturating_sub(max_length / 2);
                 let end = rope
                     .len_chars()
                     .min(cursor_index + (max_length - (cursor_index - start)));
-                if is_chat_enabled {
+                if r.0 {
                     rope.insert(cursor_index, "<CURSOR>");
                     let rope_slice = rope
                         .get_slice(start..end + "<CURSOR>".chars().count())
                         .context("Error getting rope slice")?;
@@ -154,8 +142,8 @@ impl FileStore {
                 }
             }
             _ => {
-                let start =
-                    cursor_index.saturating_sub(tokens_to_estimated_characters(max_context_length));
+                let start = cursor_index
+                    .saturating_sub(tokens_to_estimated_characters(params.max_context_length));
                 let rope_slice = rope
                     .get_slice(start..cursor_index)
                     .context("Error getting rope slice")?;
@@ -190,10 +178,10 @@ impl MemoryBackend for FileStore {
     async fn build_prompt(
         &self,
         position: &TextDocumentPositionParams,
-        max_context_length: usize,
-        prompt_for_type: PromptForType,
+        params: Value,
     ) -> anyhow::Result<Prompt> {
-        let code = self.build_code(position, prompt_for_type, max_context_length)?;
+        let params: MemoryRunParams = serde_json::from_value(params)?;
+        let code = self.build_code(position, params)?;
         Ok(Prompt::new("".to_string(), code))
     }
 
diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs
index d2ef5e5..62531f4 100644
--- a/src/memory_backends/mod.rs
+++ b/src/memory_backends/mod.rs
@@ -2,12 +2,26 @@ use lsp_types::{
     DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
     TextDocumentPositionParams,
 };
+use serde::Deserialize;
+use serde_json::Value;
 
-use crate::config::{Config, ValidMemoryBackend};
+use crate::config::{ChatMessage, Config, ValidMemoryBackend, FIM};
 
 pub mod file_store;
 mod postgresml;
 
+const fn max_context_length_default() -> usize {
+    1024
+}
+
+#[derive(Clone, Deserialize)]
+struct MemoryRunParams {
+    pub fim: Option<FIM>,
+    pub chat: Option<Vec<ChatMessage>>,
+    #[serde(default = "max_context_length_default")]
+    pub max_context_length: usize,
+}
+
 #[derive(Debug)]
 pub struct Prompt {
     pub context: String,
@@ -20,12 +34,6 @@ impl Prompt {
     }
 }
 
-#[derive(Debug)]
-pub enum PromptForType {
-    Completion,
-    Generate,
-}
-
 #[async_trait::async_trait]
 pub trait MemoryBackend {
     async fn init(&self) -> anyhow::Result<()> {
@@ -40,8 +48,7 @@ pub trait MemoryBackend {
     async fn build_prompt(
         &self,
         position: &TextDocumentPositionParams,
-        max_context_length: usize,
-        prompt_for_type: PromptForType,
+        params: Value,
     ) -> anyhow::Result<Prompt>;
     async fn get_filter_text(
         &self,
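`MemoryRunParams` above is built straight from an untyped `serde_json::Value`: missing keys fall back to serde defaults (for example `max_context_length` defaults to 1024) and unknown keys are ignored. A standalone sketch of that deserialization pattern, with simplified stand-ins for `FIM` and `ChatMessage` rather than the crate's real config types:

```rust
// Standalone sketch of Value -> typed run params with serde defaults.
use serde::Deserialize;
use serde_json::{json, Value};

// Simplified stand-ins for the crate's FIM and ChatMessage config types.
#[derive(Debug, Deserialize)]
struct Fim {
    start: String,
    end: String,
}

#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    content: String,
}

const fn max_context_length_default() -> usize {
    1024
}

#[derive(Debug, Deserialize)]
struct MemoryRunParams {
    fim: Option<Fim>,
    chat: Option<Vec<ChatMessage>>,
    #[serde(default = "max_context_length_default")]
    max_context_length: usize,
}

fn main() -> Result<(), serde_json::Error> {
    // Only an unrelated key is present: fim and chat come back as None and
    // max_context_length falls back to its default of 1024.
    let params: Value = json!({ "max_tokens": 64 });
    let run_params: MemoryRunParams = serde_json::from_value(params)?;
    println!("{run_params:?}");
    Ok(())
}
```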
diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs
index 34e858a..608b111 100644
--- a/src/memory_backends/postgresml/mod.rs
+++ b/src/memory_backends/postgresml/mod.rs
@@ -6,7 +6,7 @@ use std::{
 use anyhow::Context;
 use lsp_types::TextDocumentPositionParams;
 use pgml::{Collection, Pipeline};
-use serde_json::json;
+use serde_json::{json, Value};
 use tokio::time;
 use tracing::instrument;
 
@@ -15,7 +15,7 @@ use crate::{
     utils::tokens_to_estimated_characters,
 };
 
-use super::{file_store::FileStore, MemoryBackend, Prompt, PromptForType};
+use super::{file_store::FileStore, MemoryBackend, MemoryRunParams, Prompt};
 
 pub struct PostgresML {
     configuration: Config,
@@ -129,9 +129,9 @@ impl MemoryBackend for PostgresML {
     async fn build_prompt(
         &self,
         position: &TextDocumentPositionParams,
-        max_context_length: usize,
-        prompt_for_type: PromptForType,
+        params: Value,
     ) -> anyhow::Result<Prompt> {
+        let params: MemoryRunParams = serde_json::from_value(params)?;
         let query = self
             .file_store
             .get_characters_around_position(position, 512)?;
@@ -162,8 +162,10 @@ impl MemoryBackend for PostgresML {
             })
             .collect::<anyhow::Result<Vec<String>>>()?
             .join("\n\n");
-        let code = self.file_store.build_code(position, prompt_for_type, 512)?;
-        let max_characters = tokens_to_estimated_characters(max_context_length);
+        let mut file_store_params = params.clone();
+        file_store_params.max_context_length = 512;
+        let code = self.file_store.build_code(position, file_store_params)?;
+        let max_characters = tokens_to_estimated_characters(params.max_context_length);
         let context: String = context
             .chars()
             .take(max_characters - code.chars().count())
diff --git a/src/memory_worker.rs b/src/memory_worker.rs
index 70d086b..21b259b 100644
--- a/src/memory_worker.rs
+++ b/src/memory_worker.rs
@@ -4,29 +4,27 @@ use lsp_types::{
     DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
     TextDocumentPositionParams,
 };
+use serde_json::Value;
 use tracing::error;
 
-use crate::memory_backends::{MemoryBackend, Prompt, PromptForType};
+use crate::memory_backends::{MemoryBackend, Prompt};
 
 #[derive(Debug)]
 pub struct PromptRequest {
     position: TextDocumentPositionParams,
-    max_context_length: usize,
-    prompt_for_type: PromptForType,
+    params: Value,
     tx: tokio::sync::oneshot::Sender<Prompt>,
 }
 
 impl PromptRequest {
     pub fn new(
         position: TextDocumentPositionParams,
-        max_context_length: usize,
-        prompt_for_type: PromptForType,
+        params: Value,
         tx: tokio::sync::oneshot::Sender<Prompt>,
     ) -> Self {
         Self {
             position,
-            max_context_length,
-            prompt_for_type,
+            params,
             tx,
         }
     }
@@ -69,11 +67,7 @@ async fn do_task(
         }
         WorkerRequest::Prompt(params) => {
             let prompt = memory_backend
-                .build_prompt(
-                    &params.position,
-                    params.max_context_length,
-                    params.prompt_for_type,
-                )
+                .build_prompt(&params.position, params.params)
                 .await?;
             params
                 .tx
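After this change the memory worker's `PromptRequest` carries the raw `params: Value` together with a oneshot sender for the reply. A standalone sketch of that request/reply shape, assuming a tokio runtime and using simplified stand-ins for `PromptRequest` and `Prompt`:

```rust
// Standalone sketch of the oneshot request/reply pattern the worker uses.
use serde_json::{json, Value};
use tokio::sync::oneshot;

// Simplified stand-in: the real request also carries a text document position.
struct PromptRequest {
    params: Value,
    tx: oneshot::Sender<String>,
}

async fn worker(request: PromptRequest) {
    // A real backend would deserialize `params` into its run-params type here.
    let prompt = format!("prompt built with params: {}", request.params);
    let _ = request.tx.send(prompt);
}

#[tokio::main]
async fn main() {
    let (tx, rx) = oneshot::channel();
    let request = PromptRequest {
        params: json!({ "max_context_length": 512 }),
        tx,
    };
    tokio::spawn(worker(request));
    let prompt = rx.await.expect("worker dropped the sender");
    println!("{prompt}");
}
```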
diff --git a/src/transformer_backends/anthropic.rs b/src/transformer_backends/anthropic.rs
index 73d07b6..c48ed41 100644
--- a/src/transformer_backends/anthropic.rs
+++ b/src/transformer_backends/anthropic.rs
@@ -13,7 +13,7 @@ use crate::{
     utils::format_chat_messages,
 };
 
-use super::{RunParams, TransformerBackend};
+use super::TransformerBackend;
 
 const fn max_tokens_default() -> usize {
     64
@@ -27,10 +27,6 @@ const fn temperature_default() -> f32 {
     0.1
 }
 
-const fn max_context_length_default() -> usize {
-    1024
-}
-
 #[derive(Debug, Deserialize)]
 pub struct AnthropicRunParams {
     chat: Vec<ChatMessage>,
@@ -40,8 +36,6 @@ pub struct AnthropicRunParams {
     pub top_p: f32,
     #[serde(default = "temperature_default")]
     pub temperature: f32,
-    #[serde(default = "max_context_length_default")]
-    max_context_length: usize,
 }
 
 pub struct Anthropic {
@@ -133,9 +127,10 @@ impl TransformerBackend for Anthropic {
     async fn do_completion(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoCompletionResponse> {
-        let params: AnthropicRunParams = params.try_into()?;
+        // let params: AnthropicRunParams = params.try_into()?;
+        let params: AnthropicRunParams = serde_json::from_value(params)?;
         let insert_text = self.do_get_chat(prompt, params).await?;
         Ok(DoCompletionResponse { insert_text })
     }
@@ -144,9 +139,9 @@ impl TransformerBackend for Anthropic {
     async fn do_generate(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationResponse> {
-        let params: AnthropicRunParams = params.try_into()?;
+        let params: AnthropicRunParams = serde_json::from_value(params)?;
         let generated_text = self.do_get_chat(prompt, params).await?;
         Ok(DoGenerationResponse { generated_text })
     }
@@ -155,7 +150,7 @@ impl TransformerBackend for Anthropic {
     async fn do_generate_stream(
         &self,
         request: &GenerationStreamRequest,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationStreamResponse> {
         unimplemented!()
     }
diff --git a/src/transformer_backends/llama_cpp/mod.rs b/src/transformer_backends/llama_cpp/mod.rs
index 866a92b..85f9b83 100644
--- a/src/transformer_backends/llama_cpp/mod.rs
+++ b/src/transformer_backends/llama_cpp/mod.rs
@@ -1,6 +1,7 @@
 use anyhow::Context;
 use hf_hub::api::sync::ApiBuilder;
 use serde::Deserialize;
+use serde_json::Value;
 use tracing::instrument;
 
 use crate::{
@@ -17,7 +18,7 @@ use crate::{
 mod model;
 use model::Model;
 
-use super::{RunParams, TransformerBackend};
+use super::TransformerBackend;
 
 const fn max_new_tokens_default() -> usize {
     32
@@ -86,9 +87,9 @@ impl TransformerBackend for LLaMACPP {
     async fn do_completion(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoCompletionResponse> {
-        let params: LLaMACPPRunParams = params.try_into()?;
+        let params: LLaMACPPRunParams = serde_json::from_value(params)?;
         let prompt = self.get_prompt_string(prompt, &params)?;
         self.model
             .complete(&prompt, params)
@@ -99,9 +100,9 @@ impl TransformerBackend for LLaMACPP {
     async fn do_generate(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationResponse> {
-        let params: LLaMACPPRunParams = params.try_into()?;
+        let params: LLaMACPPRunParams = serde_json::from_value(params)?;
         let prompt = self.get_prompt_string(prompt, &params)?;
         self.model
             .complete(&prompt, params)
@@ -112,7 +113,7 @@ impl TransformerBackend for LLaMACPP {
     async fn do_generate_stream(
         &self,
         _request: &GenerationStreamRequest,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationStreamResponse> {
         unimplemented!()
     }
diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs
index d840390..b9d1dbc 100644
--- a/src/transformer_backends/mod.rs
+++ b/src/transformer_backends/mod.rs
@@ -1,3 +1,5 @@
+use serde_json::Value;
+
 use crate::{
     config::{self, ValidModel},
     memory_backends::Prompt,
@@ -13,56 +15,47 @@ mod anthropic;
 mod llama_cpp;
 mod openai;
 
-#[derive(Debug)]
-pub enum RunParams {
-    LLaMACPP(llama_cpp::LLaMACPPRunParams),
-    Anthropic(anthropic::AnthropicRunParams),
-    OpenAI(openai::OpenAIRunParams),
-}
+// impl RunParams {
+//     pub fn from_completion(completion: &Completion) -> Self {
+//         todo!()
+//     }
+// }
 
-impl RunParams {
-    pub fn from_completion(completion: &Completion) -> Self {
-        todo!()
-    }
-}
+// macro_rules! impl_runparams_try_into {
+//     ( $f:ident, $t:ident ) => {
+//         impl TryInto<$f> for RunParams {
+//             type Error = anyhow::Error;
 
-macro_rules! impl_runparams_try_into {
-    ( $f:ident, $t:ident ) => {
-        impl TryInto<$f> for RunParams {
-            type Error = anyhow::Error;
+//             fn try_into(self) -> Result<$f, Self::Error> {
+//                 match self {
+//                     Self::$t(a) => Ok(a),
+//                     _ => anyhow::bail!("Cannot convert RunParams into {}", stringify!($f)),
+//                 }
+//             }
+//         }
+//     };
+// }
 
-            fn try_into(self) -> Result<$f, Self::Error> {
-                match self {
-                    Self::$t(a) => Ok(a),
-                    _ => anyhow::bail!("Cannot convert RunParams into {}", stringify!($f)),
-                }
-            }
-        }
-    };
-}
-
-impl_runparams_try_into!(AnthropicRunParams, Anthropic);
-impl_runparams_try_into!(LLaMACPPRunParams, LLaMACPP);
-impl_runparams_try_into!(OpenAIRunParams, OpenAI);
+// impl_runparams_try_into!(AnthropicRunParams, Anthropic);
+// impl_runparams_try_into!(LLaMACPPRunParams, LLaMACPP);
+// impl_runparams_try_into!(OpenAIRunParams, OpenAI);
 
 #[async_trait::async_trait]
 pub trait TransformerBackend {
-    type Test = LLaMACPPRunParams;
-
     async fn do_completion(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoCompletionResponse>;
     async fn do_generate(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationResponse>;
     async fn do_generate_stream(
         &self,
         request: &GenerationStreamRequest,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationStreamResponse>;
 }
diff --git a/src/transformer_backends/openai/mod.rs b/src/transformer_backends/openai/mod.rs
index a092567..2f31450 100644
--- a/src/transformer_backends/openai/mod.rs
+++ b/src/transformer_backends/openai/mod.rs
@@ -13,7 +13,7 @@ use crate::{
     utils::{format_chat_messages, format_context_code},
 };
 
-use super::{RunParams, TransformerBackend};
+use super::TransformerBackend;
 
 const fn max_tokens_default() -> usize {
     64
@@ -35,10 +35,6 @@ const fn temperature_default() -> f32 {
     0.1
 }
 
-const fn max_context_length_default() -> usize {
-    1024
-}
-
 #[derive(Debug, Deserialize)]
 pub struct OpenAIRunParams {
     pub fim: Option<FIM>,
@@ -53,8 +49,6 @@ pub struct OpenAIRunParams {
     pub frequency_penalty: f32,
     #[serde(default = "temperature_default")]
     pub temperature: f32,
-    #[serde(default = "max_context_length_default")]
-    max_context_length: usize,
 }
 
 pub struct OpenAI {
@@ -202,9 +196,10 @@ impl TransformerBackend for OpenAI {
     async fn do_completion(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoCompletionResponse> {
-        let params: OpenAIRunParams = params.try_into()?;
+        // let params: OpenAIRunParams = params.try_into()?;
+        let params: OpenAIRunParams = serde_json::from_value(params)?;
         let insert_text = self.do_chat_completion(prompt, params).await?;
         Ok(DoCompletionResponse { insert_text })
     }
@@ -214,9 +209,9 @@ impl TransformerBackend for OpenAI {
     async fn do_generate(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationResponse> {
-        let params: OpenAIRunParams = params.try_into()?;
+        let params: OpenAIRunParams = serde_json::from_value(params)?;
         let generated_text = self.do_chat_completion(prompt, params).await?;
         Ok(DoGenerationResponse { generated_text })
     }
@@ -225,7 +220,7 @@ impl TransformerBackend for OpenAI {
     async fn do_generate_stream(
         &self,
         request: &GenerationStreamRequest,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationStreamResponse> {
         unimplemented!()
     }
diff --git a/src/transformer_worker.rs b/src/transformer_worker.rs
index b25e303..6e88fab 100644
--- a/src/transformer_worker.rs
+++ b/src/transformer_worker.rs
@@ -14,7 +14,6 @@ use tracing::{error, instrument};
 use crate::config::Config;
 use crate::custom_requests::generation::{GenerateResult, GenerationParams};
 use crate::custom_requests::generation_stream::GenerationStreamParams;
-use crate::memory_backends::PromptForType;
 use crate::memory_worker::{self, FilterRequest, PromptRequest};
 use crate::transformer_backends::TransformerBackend;
 use crate::utils::ToResponseError;
@@ -219,15 +218,13 @@ async fn do_completion(
     // TODO: Fix this
     // we need to be subtracting the completion / generation tokens from max_context_length
     // not sure if we should be doing that for the chat maybe leave a note here for that?
-    // let max_context_length = config.get_completion_max_context_length()?;
-    let params:
-
+    let params = serde_json::to_value(config.config.completion.kwargs.clone()).unwrap();
+
     let (tx, rx) = oneshot::channel();
     memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
         request.params.text_document_position.clone(),
-        max_context_length,
-        PromptForType::Completion,
+        params.clone(),
         tx,
     )))?;
     let prompt = rx.await?;
 
@@ -238,7 +235,7 @@ async fn do_completion(
     ))?;
     let filter_text = rx.await?;
 
-    let response = transformer_backend.do_completion(&prompt).await?;
+    let response = transformer_backend.do_completion(&prompt, params).await?;
     let completion_text_edit = TextEdit::new(
         Range::new(
             Position::new(
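Taken together, the patch drops the shared `RunParams` enum and its `TryInto` conversions: the worker serializes `completion.kwargs` into one `serde_json::Value`, and each `TransformerBackend` deserializes that value into its own params struct. A standalone sketch of that overall shape, using simplified stand-in types rather than the crate's real API:

```rust
// Standalone sketch: each backend owns its parameter schema and deserializes
// the shared, untyped Value itself.
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::{json, Value};

const fn max_tokens_default() -> usize {
    64
}

const fn temperature_default() -> f32 {
    0.1
}

#[derive(Debug, Deserialize)]
struct ExampleRunParams {
    #[serde(default = "max_tokens_default")]
    max_tokens: usize,
    #[serde(default = "temperature_default")]
    temperature: f32,
}

#[async_trait]
trait TransformerBackend {
    async fn do_completion(&self, prompt: &str, params: Value) -> anyhow::Result<String>;
}

struct ExampleBackend;

#[async_trait]
impl TransformerBackend for ExampleBackend {
    async fn do_completion(&self, prompt: &str, params: Value) -> anyhow::Result<String> {
        // Unknown keys in `params` are ignored; missing keys use serde defaults.
        let params: ExampleRunParams = serde_json::from_value(params)?;
        Ok(format!(
            "completing {prompt:?} with max_tokens={} temperature={}",
            params.max_tokens, params.temperature
        ))
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // The same kwargs-derived Value is handed to whichever backend is active.
    let params = json!({ "max_tokens": 32, "unrelated_key": true });
    let text = ExampleBackend.do_completion("fn main() {", params).await?;
    println!("{text}");
    Ok(())
}
```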