Overhaul done

This commit is contained in:
SilasMarvin
2024-03-23 19:01:05 -07:00
parent 2f71a4de3e
commit fa8e19c1ce
6 changed files with 316 additions and 65 deletions

View File

@@ -19,6 +19,7 @@ pub enum ValidMemoryBackend {
pub enum ValidTransformerBackend {
LlamaCPP(ModelGGUF),
OpenAI(OpenAI),
Anthropic(Anthropic),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -36,6 +37,7 @@ pub struct Chat {
}
#[derive(Clone, Debug, Deserialize)]
#[allow(clippy::upper_case_acronyms)]
pub struct FIM {
pub start: String,
pub middle: String,
@@ -145,7 +147,7 @@ const fn openai_max_context() -> usize {
DEFAULT_OPENAI_MAX_CONTEXT
}
#[derive(Clone, Debug, Default, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct OpenAI {
// The auth token env var name
pub auth_token_env_var_name: Option<String>,
@@ -176,9 +178,35 @@ pub struct OpenAI {
max_context: usize,
}
#[derive(Clone, Debug, Deserialize)]
pub struct Anthropic {
// The auth token env var name
pub auth_token_env_var_name: Option<String>,
pub auth_token: Option<String>,
// The completions endpoint
pub completions_endpoint: Option<String>,
// The chat endpoint
pub chat_endpoint: Option<String>,
// The model name
pub model: String,
// Fill in the middle support
pub fim: Option<FIM>,
// The maximum number of new tokens to generate
#[serde(default)]
pub max_tokens: MaxTokens,
// Chat args
pub chat: Chat,
// Sampling parameters
#[serde(default = "openai_top_p_default")]
pub top_p: f32,
#[serde(default = "openai_temperature")]
pub temperature: f32,
}
#[derive(Clone, Debug, Deserialize)]
struct ValidTransformerConfiguration {
openai: Option<OpenAI>,
anthropic: Option<Anthropic>,
model_gguf: Option<ModelGGUF>,
}
@@ -186,6 +214,7 @@ impl Default for ValidTransformerConfiguration {
fn default() -> Self {
Self {
model_gguf: Some(ModelGGUF::default()),
anthropic: None,
openai: None,
}
}
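// Illustrative sketch only, not part of this commit: with the `anthropic` field
// added above, a transformer section along these lines would select the new
// backend. Field names come from the `Anthropic` struct; the exact user-facing
// nesting is an assumption based on `ValidTransformerConfiguration`.
fn _example_anthropic_transformer_section() -> serde_json::Value {
    serde_json::json!({
        "anthropic": {
            "model": "claude-3-haiku-20240307",
            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
            "chat_endpoint": "https://api.anthropic.com/v1/messages",
            "chat": {
                "completion": [
                    { "role": "system", "content": "You are a coding assistant." },
                    { "role": "user", "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}" }
                ]
            }
        }
    })
}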

View File

@@ -72,10 +72,6 @@ fn main() -> Result<()> {
Ok(())
}
// This main loop is tricky
// We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get
// Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request
// Note that we also want to have the memory backend in the worker thread as that may also involve heavy computations
fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
// Build our configuration
let configuration = Configuration::new(args)?;
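// Illustrative sketch only, not part of this commit: one way to implement the
// "only process the latest completion request" behaviour the comment above
// describes, using a plain std::sync::mpsc channel. The worker blocks for one
// request, then drains anything that queued up behind it and keeps only the
// newest before doing the heavy lifting.
fn _spawn_latest_only_worker() -> std::sync::mpsc::Sender<String> {
    let (tx, rx) = std::sync::mpsc::channel::<String>();
    std::thread::spawn(move || {
        while let Ok(mut request) = rx.recv() {
            // Drop stale requests; keep only the most recent one.
            while let Ok(newer) = rx.try_recv() {
                request = newer;
            }
            // Model inference / memory backend work would happen here.
            eprintln!("processing request: {request}");
        }
    });
    tx
}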

View File

@@ -73,4 +73,11 @@ impl Prompt {
code: r#"def test_code():\n <CURSOR>"#.to_string(),
}
}
pub fn default_without_cursor() -> Self {
Self {
context: r#"def test_context():\n pass"#.to_string(),
code: r#"def test_code():\n "#.to_string(),
}
}
}

View File

@@ -0,0 +1,207 @@
use anyhow::Context;
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::instrument;
use crate::{
configuration::{self, ChatMessage},
memory_backends::Prompt,
transformer_worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
},
utils::format_chat_messages,
};
use super::TransformerBackend;
pub struct Anthropic {
configuration: configuration::Anthropic,
}
#[derive(Deserialize)]
struct AnthropicChatMessage {
text: String,
}
#[derive(Deserialize)]
struct AnthropicChatResponse {
content: Option<Vec<AnthropicChatMessage>>,
error: Option<Value>,
}
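// Illustrative sketch only, not part of this commit: the subset of an Anthropic
// messages-API response the structs above are shaped to pick out. Any extra
// fields in the real response are ignored by serde.
fn _example_parse_anthropic_response() -> anyhow::Result<()> {
    let body = serde_json::json!({
        "content": [{ "type": "text", "text": "fn main() {}" }]
    });
    let parsed: AnthropicChatResponse = serde_json::from_value(body)?;
    assert_eq!(parsed.content.unwrap()[0].text, "fn main() {}");
    Ok(())
}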
impl Anthropic {
#[instrument]
pub fn new(configuration: configuration::Anthropic) -> Self {
Self { configuration }
}
async fn get_chat(
&self,
system_prompt: String,
messages: Vec<ChatMessage>,
max_tokens: usize,
) -> anyhow::Result<String> {
eprintln!(
"SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
messages
);
let client = reqwest::Client::new();
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
std::env::var(env_var_name)?
} else if let Some(token) = &self.configuration.auth_token {
token.to_string()
} else {
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
};
let res: AnthropicChatResponse = client
.post(
self.configuration
.chat_endpoint
.as_ref()
.context("must specify `completions_endpoint` to use completions")?,
)
.header("x-api-key", token)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&json!({
"model": self.configuration.model,
"system": system_prompt,
"max_tokens": max_tokens,
"top_p": self.configuration.top_p,
"temperature": self.configuration.temperature,
"messages": messages
}))
.send()
.await?
.json()
.await?;
if let Some(error) = res.error {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(mut content) = res.content {
Ok(std::mem::take(&mut content[0].text))
} else {
anyhow::bail!("Uknown error while making request to OpenAI")
}
}
async fn do_get_chat(
&self,
prompt: &Prompt,
messages: &[ChatMessage],
max_tokens: usize,
) -> anyhow::Result<String> {
let mut messages = format_chat_messages(messages, prompt);
if messages[0].role != "system" {
anyhow::bail!(
"When using Anthropic, the first message in chat must have role = `system`"
)
}
let system_prompt = messages.remove(0).content;
self.get_chat(system_prompt, messages, max_tokens).await
}
}
#[async_trait::async_trait]
impl TransformerBackend for Anthropic {
#[instrument(skip(self))]
async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
eprintln!("--------------{:?}---------------", prompt);
let max_tokens = self.configuration.max_tokens.completion;
let insert_text = match &self.configuration.chat.completion {
Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
None => {
anyhow::bail!("Please provide `anthropic->chat->completion` messages")
}
};
Ok(DoCompletionResponse { insert_text })
}
#[instrument(skip(self))]
async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
eprintln!("--------------{:?}---------------", prompt);
let max_tokens = self.configuration.max_tokens.generation;
let generated_text = match &self.configuration.chat.generation {
Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
None => {
anyhow::bail!("Please provide `anthropic->chat->generation` messages")
}
};
Ok(DoGenerateResponse { generated_text })
}
#[instrument(skip(self))]
async fn do_generate_stream(
&self,
request: &GenerateStreamRequest,
) -> anyhow::Result<DoGenerateStreamResponse> {
unimplemented!()
}
}
#[cfg(test)]
mod test {
use super::*;
#[tokio::test]
async fn anthropic_chat_do_completion() -> anyhow::Result<()> {
let configuration: configuration::Anthropic = serde_json::from_value(json!({
"chat_endpoint": "https://api.anthropic.com/v1/messages",
"model": "claude-3-haiku-20240307",
"auth_token_env_var_name": "ANTHROPIC_API_KEY",
"chat": {
"completion": [
{
"role": "system",
"content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
},
{
"role": "user",
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
}
],
},
"max_tokens": {
"completion": 16,
"generation": 64
},
"max_context": 4096
}))?;
let anthropic = Anthropic::new(configuration);
let prompt = Prompt::default_with_cursor();
let response = anthropic.do_completion(&prompt).await?;
assert!(!response.insert_text.is_empty());
Ok(())
}
#[tokio::test]
async fn anthropic_chat_do_generate() -> anyhow::Result<()> {
let configuration: configuration::Anthropic = serde_json::from_value(json!({
"chat_endpoint": "https://api.anthropic.com/v1/messages",
"model": "claude-3-haiku-20240307",
"auth_token_env_var_name": "ANTHROPIC_API_KEY",
"chat": {
"generation": [
{
"role": "system",
"content": "You are a coding assistant. You job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <cursor> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
},
{
"role": "user",
"content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
}
]
},
"max_tokens": {
"completion": 16,
"generation": 64
},
"max_context": 4096
}))?;
let anthropic = Anthropic::new(configuration);
let prompt = Prompt::default_with_cursor();
let response = anthropic.do_generate(&prompt).await?;
assert!(!response.generated_text.is_empty());
Ok(())
}
}

View File

@@ -6,6 +6,7 @@ use crate::{
},
};
mod anthropic;
mod llama_cpp;
mod openai;
@@ -30,6 +31,9 @@ impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send + Sync> {
ValidTransformerBackend::OpenAI(openai_config) => {
Ok(Box::new(openai::OpenAI::new(openai_config)))
}
ValidTransformerBackend::Anthropic(anthropic_config) => {
Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
}
}
}
}

View File

@@ -49,22 +49,26 @@ impl OpenAI {
Self { configuration }
}
fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
let client = reqwest::blocking::Client::new();
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
std::env::var(env_var_name)?
fn get_token(&self) -> anyhow::Result<String> {
if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
Ok(std::env::var(env_var_name)?)
} else if let Some(token) = &self.configuration.auth_token {
token.to_string()
Ok(token.to_string())
} else {
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
};
anyhow::bail!("set `auth_token_env_var_name` or `auth_token` in `tranformer->openai` to use an OpenAI compatible API")
}
}
async fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: OpenAICompletionsResponse = client
.post(
self.configuration
.completions_endpoint
.as_ref()
.context("must specify `completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `chat_endpoint` and `chat` messages.")?,
.context("specify `transformer->openai->completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `transformer->openai->chat_endpoint` and `transformer->openai->chat` messages.")?,
)
.bearer_auth(token)
.header("Content-Type", "application/json")
@@ -80,30 +84,28 @@ impl OpenAI {
"echo": false,
"prompt": prompt
}))
.send()?
.json()?;
.send().await?
.json().await?;
if let Some(error) = res.error {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(choices) = res.choices {
Ok(choices[0].text.clone())
} else if let Some(mut choices) = res.choices {
Ok(std::mem::take(&mut choices[0].text))
} else {
anyhow::bail!("Uknown error while making request to OpenAI")
}
}
fn get_chat(&self, messages: Vec<ChatMessage>, max_tokens: usize) -> anyhow::Result<String> {
async fn get_chat(
&self,
messages: Vec<ChatMessage>,
max_tokens: usize,
) -> anyhow::Result<String> {
eprintln!(
"SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
messages
);
let client = reqwest::blocking::Client::new();
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
std::env::var(env_var_name)?
} else if let Some(token) = &self.configuration.auth_token {
token.to_string()
} else {
anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
};
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: OpenAIChatResponse = client
.post(
self.configuration
@@ -124,8 +126,10 @@ impl OpenAI {
"temperature": self.configuration.temperature,
"messages": messages
}))
.send()?
.json()?;
.send()
.await?
.json()
.await?;
if let Some(error) = res.error {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(choices) = res.choices {
@@ -134,6 +138,27 @@ impl OpenAI {
anyhow::bail!("Uknown error while making request to OpenAI")
}
}
async fn do_chat_completion(
&self,
prompt: &Prompt,
messages: Option<&Vec<ChatMessage>>,
max_tokens: usize,
) -> anyhow::Result<String> {
match messages {
Some(completion_messages) => {
let messages = format_chat_messages(completion_messages, prompt);
self.get_chat(messages, max_tokens).await
}
None => {
self.get_completion(
&format_context_code(&prompt.context, &prompt.code),
max_tokens,
)
.await
}
}
}
}
#[async_trait::async_trait]
@@ -142,22 +167,14 @@ impl TransformerBackend for OpenAI {
async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
eprintln!("--------------{:?}---------------", prompt);
let max_tokens = self.configuration.max_tokens.completion;
let insert_text = match &self.configuration.chat {
Some(c) => match &c.completion {
Some(completion_messages) => {
let messages = format_chat_messages(completion_messages, prompt);
self.get_chat(messages, max_tokens)?
}
None => self.get_completion(
&format_context_code(&prompt.context, &prompt.code),
max_tokens,
)?,
},
None => self.get_completion(
&format_context_code(&prompt.context, &prompt.code),
max_tokens,
)?,
};
let messages = self
.configuration
.chat
.as_ref()
.and_then(|c| c.completion.as_ref());
let insert_text = self
.do_chat_completion(prompt, messages, max_tokens)
.await?;
Ok(DoCompletionResponse { insert_text })
}
@@ -165,22 +182,14 @@ impl TransformerBackend for OpenAI {
async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
eprintln!("--------------{:?}---------------", prompt);
let max_tokens = self.configuration.max_tokens.generation;
let generated_text = match &self.configuration.chat {
Some(c) => match &c.generation {
Some(completion_messages) => {
let messages = format_chat_messages(completion_messages, prompt);
self.get_chat(messages, max_tokens)?
}
None => self.get_completion(
&format_context_code(&prompt.context, &prompt.code),
max_tokens,
)?,
},
None => self.get_completion(
&format_context_code(&prompt.context, &prompt.code),
max_tokens,
)?,
};
let messages = self
.configuration
.chat
.as_ref()
.and_then(|c| c.generation.as_ref());
let generated_text = self
.do_chat_completion(prompt, messages, max_tokens)
.await?;
Ok(DoGenerateResponse { generated_text })
}
@@ -210,7 +219,7 @@ mod test {
"max_context": 4096
}))?;
let openai = OpenAI::new(configuration);
let prompt = Prompt::default_with_cursor();
let prompt = Prompt::default_without_cursor();
let response = openai.do_completion(&prompt).await?;
assert!(!response.insert_text.is_empty());
Ok(())
@@ -260,7 +269,7 @@ mod test {
"max_context": 4096
}))?;
let openai = OpenAI::new(configuration);
let prompt = Prompt::default_with_cursor();
let prompt = Prompt::default_without_cursor();
let response = openai.do_generate(&prompt).await?;
assert!(!response.generated_text.is_empty());
Ok(())
@@ -269,7 +278,6 @@ mod test {
#[tokio::test]
async fn openai_chat_do_generate() -> anyhow::Result<()> {
let configuration: configuration::OpenAI = serde_json::from_value(json!({
"config": {
"chat_endpoint": "https://api.openai.com/v1/chat/completions",
"model": "gpt-3.5-turbo",
"auth_token_env_var_name": "OPENAI_API_KEY",
@@ -290,7 +298,7 @@ mod test {
"generation": 64
},
"max_context": 4096
}}))?;
}))?;
let openai = OpenAI::new(configuration);
let prompt = Prompt::default_with_cursor();
let response = openai.do_generate(&prompt).await?;