mirror of https://github.com/SilasMarvin/lsp-ai.git
synced 2025-12-18 15:04:29 +01:00
Overhaul done

@@ -19,6 +19,7 @@ pub enum ValidMemoryBackend {
pub enum ValidTransformerBackend {
    LlamaCPP(ModelGGUF),
    OpenAI(OpenAI),
    Anthropic(Anthropic),
}

#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -36,6 +37,7 @@ pub struct Chat {
}

#[derive(Clone, Debug, Deserialize)]
#[allow(clippy::upper_case_acronyms)]
pub struct FIM {
    pub start: String,
    pub middle: String,
@@ -145,7 +147,7 @@ const fn openai_max_context() -> usize {
    DEFAULT_OPENAI_MAX_CONTEXT
}

#[derive(Clone, Debug, Default, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct OpenAI {
    // The auth token env var name
    pub auth_token_env_var_name: Option<String>,
@@ -176,9 +178,35 @@ pub struct OpenAI {
    max_context: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Anthropic {
    // The auth token env var name
    pub auth_token_env_var_name: Option<String>,
    pub auth_token: Option<String>,
    // The completions endpoint
    pub completions_endpoint: Option<String>,
    // The chat endpoint
    pub chat_endpoint: Option<String>,
    // The model name
    pub model: String,
    // Fill in the middle support
    pub fim: Option<FIM>,
    // The maximum number of new tokens to generate
    #[serde(default)]
    pub max_tokens: MaxTokens,
    // Chat args
    pub chat: Chat,
    // System prompt
    #[serde(default = "openai_top_p_default")]
    pub top_p: f32,
    #[serde(default = "openai_temperature")]
    pub temperature: f32,
}
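
For orientation, here is a hypothetical configuration fragment that would deserialize into the Anthropic struct above. The field names come from the struct itself; the nesting under an "anthropic" key is an assumption based on ValidTransformerConfiguration below, not something this commit spells out.

// Hypothetical example only: fields accepted by configuration::Anthropic above.
fn example_anthropic_config() -> serde_json::Value {
    serde_json::json!({
        "anthropic": {
            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
            "chat_endpoint": "https://api.anthropic.com/v1/messages",
            "model": "claude-3-haiku-20240307",
            "max_tokens": { "completion": 16, "generation": 64 },
            "chat": { "completion": [ /* system + user ChatMessage objects */ ] }
        }
    })
}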

#[derive(Clone, Debug, Deserialize)]
struct ValidTransformerConfiguration {
    openai: Option<OpenAI>,
    anthropic: Option<Anthropic>,
    model_gguf: Option<ModelGGUF>,
}

@@ -186,6 +214,7 @@ impl Default for ValidTransformerConfiguration {
    fn default() -> Self {
        Self {
            model_gguf: Some(ModelGGUF::default()),
            anthropic: None,
            openai: None,
        }
    }

@@ -72,10 +72,6 @@ fn main() -> Result<()> {
    Ok(())
}

// This main loop is tricky
// We create a worker thread that actually does the heavy lifting, because we do not want to process every completion request we receive
// Completion requests may take a few seconds given the model configuration and available hardware, and we only want to process the latest completion request
// Note that we also want the memory backend in the worker thread, as it may involve heavy computation too
fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
    // Build our configuration
    let configuration = Configuration::new(args)?;
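
To make the comment above concrete, here is a minimal sketch of the latest-request-wins worker pattern it describes. The channel setup and names are illustrative assumptions, not the actual lsp-ai internals.

use std::sync::mpsc;
use std::thread;

// Sketch: the worker drains the queue and only acts on the newest request.
fn spawn_worker(rx: mpsc::Receiver<String>) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        while let Ok(mut request) = rx.recv() {
            // Anything that queued up while we were busy supersedes `request`.
            while let Ok(newer) = rx.try_recv() {
                request = newer;
            }
            // ...run the (potentially slow) completion for `request` here...
            eprintln!("processing latest request: {request}");
        }
    })
}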

@@ -73,4 +73,11 @@ impl Prompt {
            code: r#"def test_code():\n <CURSOR>"#.to_string(),
        }
    }

    pub fn default_without_cursor() -> Self {
        Self {
            context: r#"def test_context():\n pass"#.to_string(),
            code: r#"def test_code():\n "#.to_string(),
        }
    }
}

src/transformer_backends/anthropic.rs (new file, 207 lines)
@@ -0,0 +1,207 @@
use anyhow::Context;
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::instrument;

use crate::{
    configuration::{self, ChatMessage},
    memory_backends::Prompt,
    transformer_worker::{
        DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
    },
    utils::format_chat_messages,
};

use super::TransformerBackend;

pub struct Anthropic {
    configuration: configuration::Anthropic,
}

#[derive(Deserialize)]
struct AnthropicChatMessage {
    text: String,
}

#[derive(Deserialize)]
struct AnthropicChatResponse {
    content: Option<Vec<AnthropicChatMessage>>,
    error: Option<Value>,
}

impl Anthropic {
    #[instrument]
    pub fn new(configuration: configuration::Anthropic) -> Self {
        Self { configuration }
    }

    async fn get_chat(
        &self,
        system_prompt: String,
        messages: Vec<ChatMessage>,
        max_tokens: usize,
    ) -> anyhow::Result<String> {
        eprintln!(
            "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
            messages
        );
        let client = reqwest::Client::new();
        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
            std::env::var(env_var_name)?
        } else if let Some(token) = &self.configuration.auth_token {
            token.to_string()
        } else {
            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `anthropic` to use the Anthropic API");
        };
        let res: AnthropicChatResponse = client
            .post(
                self.configuration
                    .chat_endpoint
                    .as_ref()
                    .context("must specify `chat_endpoint` to use chat")?,
            )
            .header("x-api-key", token)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json")
            .header("Accept", "application/json")
            .json(&json!({
                "model": self.configuration.model,
                "system": system_prompt,
                "max_tokens": max_tokens,
                "top_p": self.configuration.top_p,
                "temperature": self.configuration.temperature,
                "messages": messages
            }))
            .send()
            .await?
            .json()
            .await?;
        if let Some(error) = res.error {
            anyhow::bail!("{:?}", error.to_string())
        } else if let Some(mut content) = res.content {
            Ok(std::mem::take(&mut content[0].text))
        } else {
            anyhow::bail!("Unknown error while making request to Anthropic")
        }
    }

    async fn do_get_chat(
        &self,
        prompt: &Prompt,
        messages: &[ChatMessage],
        max_tokens: usize,
    ) -> anyhow::Result<String> {
        let mut messages = format_chat_messages(messages, prompt);
        if messages[0].role != "system" {
            anyhow::bail!(
                "When using Anthropic, the first message in chat must have role = `system`"
            )
        }
        let system_prompt = messages.remove(0).content;
        self.get_chat(system_prompt, messages, max_tokens).await
    }
}
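
As an illustration of what get_chat ends up sending once do_get_chat has pulled the leading system message out: the Anthropic messages API takes the system prompt as a top-level `system` field, so `messages` should begin with a user turn. The values below are placeholders, not output from the code above.

// Placeholder request body mirroring the json!(...) built in get_chat.
fn example_anthropic_request_body() -> serde_json::Value {
    serde_json::json!({
        "model": "claude-3-haiku-20240307",
        "system": "You are a coding assistant. ...",
        "max_tokens": 16,
        "top_p": 0.95,
        "temperature": 0.1,
        // Remaining messages after the system message was removed.
        "messages": [
            { "role": "user", "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}" }
        ]
    })
}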

#[async_trait::async_trait]
impl TransformerBackend for Anthropic {
    #[instrument(skip(self))]
    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
        eprintln!("--------------{:?}---------------", prompt);
        let max_tokens = self.configuration.max_tokens.completion;
        let insert_text = match &self.configuration.chat.completion {
            Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
            None => {
                anyhow::bail!("Please provide `anthropic->chat->completion` messages")
            }
        };
        Ok(DoCompletionResponse { insert_text })
    }

    #[instrument(skip(self))]
    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
        eprintln!("--------------{:?}---------------", prompt);
        let max_tokens = self.configuration.max_tokens.generation;
        let generated_text = match &self.configuration.chat.generation {
            Some(messages) => self.do_get_chat(prompt, messages, max_tokens).await?,
            None => {
                anyhow::bail!("Please provide `anthropic->chat->generation` messages")
            }
        };
        Ok(DoGenerateResponse { generated_text })
    }

    #[instrument(skip(self))]
    async fn do_generate_stream(
        &self,
        request: &GenerateStreamRequest,
    ) -> anyhow::Result<DoGenerateStreamResponse> {
        unimplemented!()
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn anthropic_chat_do_completion() -> anyhow::Result<()> {
        let configuration: configuration::Anthropic = serde_json::from_value(json!({
            "chat_endpoint": "https://api.anthropic.com/v1/messages",
            "model": "claude-3-haiku-20240307",
            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
            "chat": {
                "completion": [
                    {
                        "role": "system",
                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
                    },
                    {
                        "role": "user",
                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
                    }
                ],
            },
            "max_tokens": {
                "completion": 16,
                "generation": 64
            },
            "max_context": 4096
        }))?;
        let anthropic = Anthropic::new(configuration);
        let prompt = Prompt::default_with_cursor();
        let response = anthropic.do_completion(&prompt).await?;
        assert!(!response.insert_text.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn anthropic_chat_do_generate() -> anyhow::Result<()> {
        let configuration: configuration::Anthropic = serde_json::from_value(json!({
            "chat_endpoint": "https://api.anthropic.com/v1/messages",
            "model": "claude-3-haiku-20240307",
            "auth_token_env_var_name": "ANTHROPIC_API_KEY",
            "chat": {
                "generation": [
                    {
                        "role": "system",
                        "content": "You are a coding assistant. Your job is to generate a code snippet to replace <CURSOR>.\n\nYour instructions are to:\n- Analyze the provided [Context Code] and [Current Code].\n- Generate a concise code snippet that can replace the <CURSOR> marker in the [Current Code].\n- Do not provide any explanations or modify any code above or below the <CURSOR> position.\n- The generated code should seamlessly fit into the existing code structure and context.\n- Ensure your answer is properly indented and formatted based on the <CURSOR> location.\n- Only respond with code. Do not respond with anything that is not valid code."
                    },
                    {
                        "role": "user",
                        "content": "[Context code]:\n{CONTEXT}\n\n[Current code]:{CODE}"
                    }
                ]
            },
            "max_tokens": {
                "completion": 16,
                "generation": 64
            },
            "max_context": 4096
        }))?;
        let anthropic = Anthropic::new(configuration);
        let prompt = Prompt::default_with_cursor();
        let response = anthropic.do_generate(&prompt).await?;
        assert!(!response.generated_text.is_empty());
        Ok(())
    }
}

@@ -6,6 +6,7 @@ use crate::{
    },
};

mod anthropic;
mod llama_cpp;
mod openai;

@@ -30,6 +31,9 @@ impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send + Sync> {
            ValidTransformerBackend::OpenAI(openai_config) => {
                Ok(Box::new(openai::OpenAI::new(openai_config)))
            }
            ValidTransformerBackend::Anthropic(anthropic_config) => {
                Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
            }
        }
    }
}

@@ -49,22 +49,26 @@ impl OpenAI {
        Self { configuration }
    }

    fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
        eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
        let client = reqwest::blocking::Client::new();
        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
            std::env::var(env_var_name)?
    fn get_token(&self) -> anyhow::Result<String> {
        if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
            Ok(std::env::var(env_var_name)?)
        } else if let Some(token) = &self.configuration.auth_token {
            token.to_string()
            Ok(token.to_string())
        } else {
            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
        };
            anyhow::bail!("set `auth_token_env_var_name` or `auth_token` in `transformer->openai` to use an OpenAI compatible API")
        }
    }

    async fn get_completion(&self, prompt: &str, max_tokens: usize) -> anyhow::Result<String> {
        eprintln!("SENDING REQUEST WITH PROMPT: ******\n{}\n******", prompt);
        let client = reqwest::Client::new();
        let token = self.get_token()?;
        let res: OpenAICompletionsResponse = client
            .post(
                self.configuration
                    .completions_endpoint
                    .as_ref()
                    .context("must specify `completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `chat_endpoint` and `chat` messages.")?,
                    .context("specify `transformer->openai->completions_endpoint` to use completions. Wanted to use `chat` instead? Please specify `transformer->openai->chat_endpoint` and `transformer->openai->chat` messages.")?,
            )
            .bearer_auth(token)
            .header("Content-Type", "application/json")
@@ -80,30 +84,28 @@ impl OpenAI {
                "echo": false,
                "prompt": prompt
            }))
            .send()?
            .json()?;
            .send().await?
            .json().await?;
        if let Some(error) = res.error {
            anyhow::bail!("{:?}", error.to_string())
        } else if let Some(choices) = res.choices {
            Ok(choices[0].text.clone())
        } else if let Some(mut choices) = res.choices {
            Ok(std::mem::take(&mut choices[0].text))
        } else {
            anyhow::bail!("Unknown error while making request to OpenAI")
        }
    }

    fn get_chat(&self, messages: Vec<ChatMessage>, max_tokens: usize) -> anyhow::Result<String> {
    async fn get_chat(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: usize,
    ) -> anyhow::Result<String> {
        eprintln!(
            "SENDING CHAT REQUEST WITH PROMPT: ******\n{:?}\n******",
            messages
        );
        let client = reqwest::blocking::Client::new();
        let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
            std::env::var(env_var_name)?
        } else if let Some(token) = &self.configuration.auth_token {
            token.to_string()
        } else {
            anyhow::bail!("Please set `auth_token_env_var_name` or `auth_token` in `openai` to use an OpenAI compatible API");
        };
        let client = reqwest::Client::new();
        let token = self.get_token()?;
        let res: OpenAIChatResponse = client
            .post(
                self.configuration
@@ -124,8 +126,10 @@ impl OpenAI {
                "temperature": self.configuration.temperature,
                "messages": messages
            }))
            .send()?
            .json()?;
            .send()
            .await?
            .json()
            .await?;
        if let Some(error) = res.error {
            anyhow::bail!("{:?}", error.to_string())
        } else if let Some(choices) = res.choices {
@@ -134,6 +138,27 @@ impl OpenAI {
            anyhow::bail!("Unknown error while making request to OpenAI")
        }
    }

    async fn do_chat_completion(
        &self,
        prompt: &Prompt,
        messages: Option<&Vec<ChatMessage>>,
        max_tokens: usize,
    ) -> anyhow::Result<String> {
        match messages {
            Some(completion_messages) => {
                let messages = format_chat_messages(completion_messages, prompt);
                self.get_chat(messages, max_tokens).await
            }
            None => {
                self.get_completion(
                    &format_context_code(&prompt.context, &prompt.code),
                    max_tokens,
                )
                .await
            }
        }
    }
}

#[async_trait::async_trait]
@@ -142,22 +167,14 @@ impl TransformerBackend for OpenAI {
    async fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
        eprintln!("--------------{:?}---------------", prompt);
        let max_tokens = self.configuration.max_tokens.completion;
        let insert_text = match &self.configuration.chat {
            Some(c) => match &c.completion {
                Some(completion_messages) => {
                    let messages = format_chat_messages(completion_messages, prompt);
                    self.get_chat(messages, max_tokens)?
                }
                None => self.get_completion(
                    &format_context_code(&prompt.context, &prompt.code),
                    max_tokens,
                )?,
            },
            None => self.get_completion(
                &format_context_code(&prompt.context, &prompt.code),
                max_tokens,
            )?,
        };
        let messages = self
            .configuration
            .chat
            .as_ref()
            .and_then(|c| c.completion.as_ref());
        let insert_text = self
            .do_chat_completion(prompt, messages, max_tokens)
            .await?;
        Ok(DoCompletionResponse { insert_text })
    }

@@ -165,22 +182,14 @@ impl TransformerBackend for OpenAI {
    async fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
        eprintln!("--------------{:?}---------------", prompt);
        let max_tokens = self.configuration.max_tokens.generation;
        let generated_text = match &self.configuration.chat {
            Some(c) => match &c.generation {
                Some(completion_messages) => {
                    let messages = format_chat_messages(completion_messages, prompt);
                    self.get_chat(messages, max_tokens)?
                }
                None => self.get_completion(
                    &format_context_code(&prompt.context, &prompt.code),
                    max_tokens,
                )?,
            },
            None => self.get_completion(
                &format_context_code(&prompt.context, &prompt.code),
                max_tokens,
            )?,
        };
        let messages = self
            .configuration
            .chat
            .as_ref()
            .and_then(|c| c.generation.as_ref());
        let generated_text = self
            .do_chat_completion(prompt, messages, max_tokens)
            .await?;
        Ok(DoGenerateResponse { generated_text })
    }

@@ -210,7 +219,7 @@ mod test {
            "max_context": 4096
        }))?;
        let openai = OpenAI::new(configuration);
        let prompt = Prompt::default_with_cursor();
        let prompt = Prompt::default_without_cursor();
        let response = openai.do_completion(&prompt).await?;
        assert!(!response.insert_text.is_empty());
        Ok(())
@@ -260,7 +269,7 @@ mod test {
            "max_context": 4096
        }))?;
        let openai = OpenAI::new(configuration);
        let prompt = Prompt::default_with_cursor();
        let prompt = Prompt::default_without_cursor();
        let response = openai.do_generate(&prompt).await?;
        assert!(!response.generated_text.is_empty());
        Ok(())
@@ -269,7 +278,6 @@ mod test {
    #[tokio::test]
    async fn openai_chat_do_generate() -> anyhow::Result<()> {
        let configuration: configuration::OpenAI = serde_json::from_value(json!({
        "config": {
            "chat_endpoint": "https://api.openai.com/v1/chat/completions",
            "model": "gpt-3.5-turbo",
            "auth_token_env_var_name": "OPENAI_API_KEY",
@@ -290,7 +298,7 @@ mod test {
                "generation": 64
            },
            "max_context": 4096
        }}))?;
        }))?;
        let openai = OpenAI::new(configuration);
        let prompt = Prompt::default_with_cursor();
        let response = openai.do_generate(&prompt).await?;