From fa8e19c1ce3d95c1f2a65e40b4c164da348d7557 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Sat, 23 Mar 2024 19:01:05 -0700
Subject: [PATCH] Overhaul done

---
 src/configuration.rs                   |  31 +++-
 src/main.rs                            |   4 -
 src/memory_backends/mod.rs             |   7 +
 src/transformer_backends/anthropic.rs  | 207 +++++++++++++++++++++++++
 src/transformer_backends/mod.rs        |   4 +
 src/transformer_backends/openai/mod.rs | 128 ++++++++-------
 6 files changed, 316 insertions(+), 65 deletions(-)
 create mode 100644 src/transformer_backends/anthropic.rs

diff --git a/src/configuration.rs b/src/configuration.rs
index c51f0bf..39f076b 100644
--- a/src/configuration.rs
+++ b/src/configuration.rs
@@ -19,6 +19,7 @@ pub enum ValidMemoryBackend {
 pub enum ValidTransformerBackend {
     LlamaCPP(ModelGGUF),
     OpenAI(OpenAI),
+    Anthropic(Anthropic),
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -36,6 +37,7 @@ pub struct Chat {
 }
 
 #[derive(Clone, Debug, Deserialize)]
+#[allow(clippy::upper_case_acronyms)]
 pub struct FIM {
     pub start: String,
     pub middle: String,
@@ -145,7 +147,7 @@ const fn openai_max_context() -> usize {
     DEFAULT_OPENAI_MAX_CONTEXT
 }
 
-#[derive(Clone, Debug, Default, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub struct OpenAI {
     // The auth token env var name
     pub auth_token_env_var_name: Option<String>,
@@ -176,9 +178,35 @@ pub struct OpenAI {
     max_context: usize,
 }
 
+#[derive(Clone, Debug, Deserialize)]
+pub struct Anthropic {
+    // The auth token env var name
+    pub auth_token_env_var_name: Option<String>,
+    pub auth_token: Option<String>,
+    // The completions endpoint
+    pub completions_endpoint: Option<String>,
+    // The chat endpoint
+    pub chat_endpoint: Option<String>,
+    // The model name
+    pub model: String,
+    // Fill in the middle support
+    pub fim: Option<FIM>,
+    // The maximum number of new tokens to generate
+    #[serde(default)]
+    pub max_tokens: MaxTokens,
+    // Chat args
+    pub chat: Chat,
+    // Sampling parameters
+    #[serde(default = "openai_top_p_default")]
+    pub top_p: f32,
+    #[serde(default = "openai_temperature")]
+    pub temperature: f32,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 struct ValidTransformerConfiguration {
     openai: Option<OpenAI>,
+    anthropic: Option<Anthropic>,
     model_gguf: Option<ModelGGUF>,
 }
 
@@ -186,6 +214,7 @@ impl Default for ValidTransformerConfiguration {
     fn default() -> Self {
         Self {
             model_gguf: Some(ModelGGUF::default()),
+            anthropic: None,
             openai: None,
         }
     }
diff --git a/src/main.rs b/src/main.rs
index 2569aea..c32a324 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -72,10 +72,6 @@ fn main() -> Result<()> {
     Ok(())
 }
 
-// This main loop is tricky
-// We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get
-// Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request
-// Note that we also want to have the memory backend in the worker thread as that may also involve heavy computations
 fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
     // Build our configuration
     let configuration = Configuration::new(args)?;
diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs
index 58e39f1..1a0832d 100644
--- a/src/memory_backends/mod.rs
+++ b/src/memory_backends/mod.rs
@@ -73,4 +73,11 @@ impl Prompt {
             code: r#"def test_code():\n    <CURSOR>"#.to_string(),
         }
     }
+
+    pub fn default_without_cursor() -> Self {
+        Self {
+            context: r#"def test_context():\n    pass"#.to_string(),
+            code: r#"def test_code():\n    "#.to_string(),
+        }
+    }
 }
diff --git a/src/transformer_backends/anthropic.rs b/src/transformer_backends/anthropic.rs
new file mode 100644
index 0000000..9d2bb9c
--- /dev/null
+++ b/src/transformer_backends/anthropic.rs
@@ -0,0 +1,207 @@
+use anyhow::Context;
+use serde::Deserialize;
+use serde_json::{json, Value};
+use tracing::instrument;
+
+use crate::{
+    configuration::{self, ChatMessage},
+    memory_backends::Prompt,
+    transformer_worker::{
+        DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
+    },
+    utils::format_chat_messages,
+};
+
+use super::TransformerBackend;
+
+pub struct Anthropic {
+    configuration: configuration::Anthropic,
+}
+
+#[derive(Deserialize)]
+struct AnthropicChatMessage {
+    text: String,
+}
+
+#[derive(Deserialize)]
+struct AnthropicChatResponse {
+    content: Option<Vec<AnthropicChatMessage>>,
+    error: Option<Value>,
+}
+
+impl Anthropic {
+    #[instrument]
+    pub fn new(configuration: configuration::Anthropic) -> Self {
+        Self { configuration }
+    }
+
+    async fn get_chat(
+        &self,
+        system_prompt: String,
+        messages: Vec<ChatMessage>,
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
+        eprintln!(
+            "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
+            messages
+        );
+        let client = reqwest::Client::new();
+        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
+            std::env::var(env_var_name)?
+        } else if let Some(token) = &self.configuration.auth_token {
+            token.to_string()
+        } else {
+            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `anthropic` to use the Anthropic API");
+        };
+        let res: AnthropicChatResponse = client
+            .post(
+                self.configuration
+                    .chat_endpoint
+                    .as_ref()
+                    .context("must specify `chat_endpoint` to use chat")?,
+            )
+            .header("x-api-key", token)
+            .header("anthropic-version", "2023-06-01")
+            .header("Content-Type", "application/json")
+            .header("Accept", "application/json")
+            .json(&json!({
+                "model": self.configuration.model,
+                "system": system_prompt,
+                "max_tokens": max_tokens,
+                "top_p": self.configuration.top_p,
+                "temperature": self.configuration.temperature,
+                "messages": messages
+            }))
+            .send()
+            .await?
+            .json()
+            .await?;
+        if let Some(error) = res.error {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(mut content) = res.content {
+            Ok(std::mem::take(&mut content[0].text))
+        } else {
+            anyhow::bail!("Unknown error while making request to Anthropic")
+        }
+    }
+
+    async fn do_get_chat(
+        &self,
+        prompt: &Prompt,
+        messages: &[ChatMessage],
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
+        let mut messages = format_chat_messages(messages, prompt);
+        if messages[0].role != "system" {
+            anyhow::bail!(
+                "When using Anthropic, the first message in chat must have role = `system`"
+            )
+        }
+        let system_prompt = messages.remove(0).content;
+        self.get_chat(system_prompt, messages, max_tokens).await
+    }
+}
+
+#[async_trait::async_trait]
+impl TransformerBackend for Anthropic {
+    #[instrument(skip(self))]
+    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
+        eprintln!("--------------{:?}---------------", prompt);
+        let max_tokens = self.configuration.max_tokens.completion;
+        let insert_text = match &self.configuration.chat.completion {
+            Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
+            None => {
+                anyhow::bail!("Please provide `anthropic->chat->completion` messages")
+            }
+        };
+        Ok(DoCompletionResponse { insert_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
+        eprintln!("--------------{:?}---------------", prompt);
+        let max_tokens = self.configuration.max_tokens.generation;
+        let generated_text = match &self.configuration.chat.generation {
+            Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
+            None => {
+                anyhow::bail!("Please provide `anthropic->chat->generation` messages")
+            }
+        };
+        Ok(DoGenerateResponse { generated_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate_stream(
+        &self,
+        request: &GenerateStreamRequest,
+    ) -> anyhow::Result<DoGenerateStreamResponse> {
+        unimplemented!()
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[tokio::test]
+    async fn anthropic_chat_do_completion() -> anyhow::Result<()> {
+        let configuration: configuration::Anthropic = serde_json::from_value(json!({
+            "chat_endpoint": "https://api.anthropic.com/v1/messages",
+            "model": "claude-3-haiku-20240307",
+            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
+            "chat": {
+                "completion": [
+                    {
+                        "role": "system",
+                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
+                    },
+                    {
+                        "role": "user",
+                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                    }
+                ],
+            },
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let anthropic = Anthropic::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = anthropic.do_completion(&prompt).await?;
+        assert!(!response.insert_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn anthropic_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: configuration::Anthropic = serde_json::from_value(json!({
+            "chat_endpoint": "https://api.anthropic.com/v1/messages",
+            "model": "claude-3-haiku-20240307",
+            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
+            "chat": {
+                "generation": [
+                    {
+                        "role": "system",
+                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
+                    },
+                    {
+                        "role": "user",
+                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
+                    }
+                ]
+            },
+            "max_tokens": {
+                "completion": 16,
+                "generation": 64
+            },
+            "max_context": 4096
+        }))?;
+        let anthropic = Anthropic::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let response = anthropic.do_generate(&prompt).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+}
diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs
index 4883cfb..bbf5386 100644
--- a/src/transformer_backends/mod.rs
+++ b/src/transformer_backends/mod.rs
@@ -6,6 +6,7 @@ use crate::{
     },
 };
 
+mod anthropic;
 mod llama_cpp;
 mod openai;
 
@@ -30,6 +31,9 @@ impl TryFrom<ValidTransformerBackend> for Box<dyn TransformerBackend + Send + Sync> {
             ValidTransformerBackend::OpenAI(openai_config) => {
                 Ok(Box::new(openai::OpenAI::new(openai_config)))
             }
+            ValidTransformerBackend::Anthropic(anthropic_config) => {
+                Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
+            }
         }
     }
 }
diff --git a/src/transformer_backends/openai/mod.rs b/src/transformer_backends/openai/mod.rs
index 5fcc2d6..29c75ee 100644
--- a/src/transformer_backends/openai/mod.rs
+++ b/src/transformer_backends/openai/mod.rs
@@ -49,22 +49,26 @@ impl OpenAI {
         Self { configuration }
     }
 
-    fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
-        eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
-        let client = reqwest::blocking::Client::new();
-        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
-            std::env::var(env_var_name)?
+    fn get_token(&self) -> anyhow::Result<String> {
+        if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
+            Ok(std::env::var(env_var_name)?)
         } else if let Some(token) = &self.configuration.auth_token {
-            token.to_string()
+            Ok(token.to_string())
         } else {
-            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
-        };
+            anyhow::bail!("set `auth_token_env_var_name` or `auth_token` in `transformer->openai` to use an OpenAI compatible API")
+        }
+    }
+
+    async fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
+        eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
+        let client = reqwest::Client::new();
+        let token = self.get_token()?;
         let res: OpenAICompletionsResponse = client
             .post(
                 self.configuration
                     .completions_endpoint
                     .as_ref()
-                    .context("must specify `completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `chat_endpoint` and `chat` messages.")?,
+                    .context("specify `transformer->openai->completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `transformer->openai->chat_endpoint` and `transformer->openai->chat` messages.")?,
             )
             .bearer_auth(token)
             .header("Content-Type", "application/json")
@@ -80,30 +84,28 @@ impl OpenAI {
                 "echo": false,
                 "prompt": prompt
             }))
-            .send()?
-            .json()?;
+            .send().await?
+            .json().await?;
         if let Some(error) = res.error {
             anyhow::bail!("{:?}", error.to_string())
-        } else if let Some(choices) = res.choices {
-            Ok(choices[0].text.clone())
+        } else if let Some(mut choices) = res.choices {
+            Ok(std::mem::take(&mut choices[0].text))
         } else {
             anyhow::bail!("Uknown error while making request to OpenAI")
         }
     }
 
-    fn get_chat(&self, messages: Vec<ChatMessage>, max_tokens: usize) -> anyhow::Result<String> {
+    async fn get_chat(
+        &self,
+        messages: Vec<ChatMessage>,
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
         eprintln!(
             "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
             messages
         );
-        let client = reqwest::blocking::Client::new();
-        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
-            std::env::var(env_var_name)?
-        } else if let Some(token) = &self.configuration.auth_token {
-            token.to_string()
-        } else {
-            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
-        };
+        let client = reqwest::Client::new();
+        let token = self.get_token()?;
         let res: OpenAIChatResponse = client
             .post(
                 self.configuration
+            .json()
+            .await?;
         if let Some(error) = res.error {
             anyhow::bail!("{:?}", error.to_string())
         } else if let Some(choices) = res.choices {
@@ -134,6 +138,27 @@
         } else {
             anyhow::bail!("Uknown error while making request to OpenAI")
         }
     }
+
+    async fn do_chat_completion(
+        &self,
+        prompt: &Prompt,
+        messages: Option<&Vec<ChatMessage>>,
+        max_tokens: usize,
+    ) -> anyhow::Result<String> {
+        match messages {
+            Some(completion_messages) => {
+                let messages = format_chat_messages(completion_messages, prompt);
+                self.get_chat(messages, max_tokens).await
+            }
+            None => {
+                self.get_completion(
+                    &format_context_code(&prompt.context, &prompt.code),
+                    max_tokens,
+                )
+                .await
+            }
+        }
+    }
 }
 
 #[async_trait::async_trait]
@@ -142,22 +167,14 @@ impl TransformerBackend for OpenAI {
     async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
         eprintln!("--------------{:?}---------------", prompt);
         let max_tokens = self.configuration.max_tokens.completion;
-        let insert_text = match &self.configuration.chat {
-            Some(c) => match &c.completion {
-                Some(completion_messages) => {
-                    let messages = format_chat_messages(completion_messages, prompt);
-                    self.get_chat(messages, max_tokens)?
-                }
-                None => self.get_completion(
-                    &format_context_code(&prompt.context, &prompt.code),
-                    max_tokens,
-                )?,
-            },
-            None => self.get_completion(
-                &format_context_code(&prompt.context, &prompt.code),
-                max_tokens,
-            )?,
-        };
+        let messages = self
+            .configuration
+            .chat
+            .as_ref()
+            .and_then(|c| c.completion.as_ref());
+        let insert_text = self
+            .do_chat_completion(prompt, messages, max_tokens)
+            .await?;
         Ok(DoCompletionResponse { insert_text })
     }
 
@@ -165,22 +182,14 @@ impl TransformerBackend for OpenAI {
     async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
         eprintln!("--------------{:?}---------------", prompt);
         let max_tokens = self.configuration.max_tokens.generation;
-        let generated_text = match &self.configuration.chat {
-            Some(c) => match &c.generation {
-                Some(completion_messages) => {
-                    let messages = format_chat_messages(completion_messages, prompt);
-                    self.get_chat(messages, max_tokens)?
-                }
-                None => self.get_completion(
-                    &format_context_code(&prompt.context, &prompt.code),
-                    max_tokens,
-                )?,
-            },
-            None => self.get_completion(
-                &format_context_code(&prompt.context, &prompt.code),
-                max_tokens,
-            )?,
-        };
+        let messages = self
+            .configuration
+            .chat
+            .as_ref()
+            .and_then(|c| c.generation.as_ref());
+        let generated_text = self
+            .do_chat_completion(prompt, messages, max_tokens)
+            .await?;
         Ok(DoGenerateResponse { generated_text })
     }
 
@@ -210,7 +219,7 @@ mod test {
             "max_context": 4096
         }))?;
         let openai = OpenAI::new(configuration);
-        let prompt = Prompt::default_with_cursor();
+        let prompt = Prompt::default_without_cursor();
         let response = openai.do_completion(&prompt).await?;
         assert!(!response.insert_text.is_empty());
         Ok(())
@@ -260,7 +269,7 @@ mod test {
             "max_context": 4096
         }))?;
         let openai = OpenAI::new(configuration);
-        let prompt = Prompt::default_with_cursor();
+        let prompt = Prompt::default_without_cursor();
         let response = openai.do_generate(&prompt).await?;
         assert!(!response.generated_text.is_empty());
         Ok(())
@@ -269,7 +278,6 @@ mod test {
     #[tokio::test]
     async fn openai_chat_do_generate() -> anyhow::Result<()> {
         let configuration: configuration::OpenAI = serde_json::from_value(json!({
-            "config": {
             "chat_endpoint": "https://api.openai.com/v1/chat/completions",
             "model": "gpt-3.5-turbo",
             "auth_token_env_var_name": "OPENAI_API_KEY",
             "chat": {
                 "generation": [
                     {
                         "role": "system",
                         "content": "Test",
                     },
                     {
                         "role": "user",
                         "content": "Test {CONTEXT} - {CODE}"
                     }
                 ]
             },
             "max_tokens": {
                 "completion": 16,
                 "generation": 64
             },
             "max_context": 4096
-        }}))?;
+        }))?;
         let openai = OpenAI::new(configuration);
         let prompt = Prompt::default_with_cursor();
         let response = openai.do_generate(&prompt).await?;
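
For reference, a client-side configuration that exercises the new `anthropic` backend would look roughly like the sketch below. This is illustrative and not part of the patch: the nesting under a top-level "transformer" key is an assumption inferred from the `transformer->openai` paths used in the error messages above, and the prompt strings are shortened. The field names come from the `configuration::Anthropic` struct added in src/configuration.rs.

    {
      "transformer": {
        "anthropic": {
          "chat_endpoint": "https://api.anthropic.com/v1/messages",
          "model": "claude-3-haiku-20240307",
          "auth_token_env_var_name": "ANTHROPIC_API_KEY",
          "chat": {
            "completion": [
              { "role": "system", "content": "You are a coding assistant..." },
              { "role": "user", "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}" }
            ],
            "generation": [
              { "role": "system", "content": "You are a coding assistant..." },
              { "role": "user", "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}" }
            ]
          },
          "max_tokens": { "completion": 16, "generation": 64 }
        }
      }
    }

Because `Anthropic::do_get_chat` rejects message lists whose first entry does not have role `system`, each list above starts with a system message; the backend then removes it and sends it as the separate top-level `system` field that Anthropic's messages endpoint expects.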