mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2025-12-19 07:24:24 +01:00
Merge pull request #23 from SilasMarvin/silas-add-initial-post-processing
Added initial post processing to remove duplicate start and end characters
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lsp-ai"
|
name = "lsp-ai"
|
||||||
version = "0.2.0"
|
version = "0.3.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
description = "LSP-AI is an open-source language server that serves as a backend for AI-powered functionality, designed to assist and empower software engineers, not replace them."
|
description = "LSP-AI is an open-source language server that serves as a backend for AI-powered functionality, designed to assist and empower software engineers, not replace them."
|
||||||
|
|||||||
@@ -9,6 +9,21 @@ const fn max_requests_per_second_default() -> f32 {
|
|||||||
1.
|
1.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct PostProcess {
|
||||||
|
pub remove_duplicate_start: bool,
|
||||||
|
pub remove_duplicate_end: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PostProcess {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
remove_duplicate_start: true,
|
||||||
|
remove_duplicate_end: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub enum ValidMemoryBackend {
|
pub enum ValidMemoryBackend {
|
||||||
#[serde(rename = "file_store")]
|
#[serde(rename = "file_store")]
|
||||||
@@ -177,10 +192,12 @@ pub struct Anthropic {
|
|||||||
pub struct Completion {
|
pub struct Completion {
|
||||||
// The model key to use
|
// The model key to use
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
|
||||||
// Args are deserialized by the backend using them
|
// Args are deserialized by the backend using them
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub parameters: Kwargs,
|
pub parameters: Kwargs,
|
||||||
|
// Parameters for post processing
|
||||||
|
#[serde(default)]
|
||||||
|
pub post_process: PostProcess,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
@@ -230,6 +247,10 @@ impl Config {
|
|||||||
self.config.completion.is_some()
|
self.config.completion.is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the post-processing settings of the completion section, or `None`
/// when no completion section is configured.
pub fn get_completions_post_process(&self) -> Option<&PostProcess> {
    let completion = self.config.completion.as_ref()?;
    Some(&completion.post_process)
}
|
||||||
|
|
||||||
pub fn get_completion_transformer_max_requests_per_second(&self) -> anyhow::Result<f32> {
|
pub fn get_completion_transformer_max_requests_per_second(&self) -> anyhow::Result<f32> {
|
||||||
match &self
|
match &self
|
||||||
.config
|
.config
|
||||||
@@ -335,6 +356,10 @@ mod test {
|
|||||||
"options": {
|
"options": {
|
||||||
"num_predict": 32
|
"num_predict": 32
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"post_process": {
|
||||||
|
"remove_duplicate_start": true,
|
||||||
|
"remove_duplicate_end": true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,20 +2,27 @@ use lsp_types::TextDocumentPositionParams;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::config;
|
||||||
|
|
||||||
pub enum Generation {}
|
pub enum Generation {}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct GenerationParams {
|
pub struct GenerationParams {
|
||||||
// This field was "mixed-in" from TextDocumentPositionParams
|
// This field was "mixed-in" from TextDocumentPositionParams
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub text_document_position: TextDocumentPositionParams,
|
pub text_document_position: TextDocumentPositionParams,
|
||||||
|
// The model key to use
|
||||||
pub model: String,
|
pub model: String,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
// Args are deserialized by the backend using them
|
||||||
pub parameters: Value,
|
pub parameters: Value,
|
||||||
|
// Parameters for post processing
|
||||||
|
#[serde(default)]
|
||||||
|
pub post_process: config::PostProcess,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct GenerateResult {
|
pub struct GenerateResult {
|
||||||
pub generated_text: String,
|
pub generated_text: String,
|
||||||
|
|||||||
@@ -11,9 +11,10 @@ use std::time::{Duration, SystemTime};
|
|||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
use tracing::{error, instrument};
|
use tracing::{error, instrument};
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::{self, Config};
|
||||||
use crate::custom_requests::generation::{GenerateResult, GenerationParams};
|
use crate::custom_requests::generation::{GenerateResult, GenerationParams};
|
||||||
use crate::custom_requests::generation_stream::GenerationStreamParams;
|
use crate::custom_requests::generation_stream::GenerationStreamParams;
|
||||||
|
use crate::memory_backends::Prompt;
|
||||||
use crate::memory_worker::{self, FilterRequest, PromptRequest};
|
use crate::memory_worker::{self, FilterRequest, PromptRequest};
|
||||||
use crate::transformer_backends::TransformerBackend;
|
use crate::transformer_backends::TransformerBackend;
|
||||||
use crate::utils::ToResponseError;
|
use crate::utils::ToResponseError;
|
||||||
@@ -85,6 +86,83 @@ pub struct DoGenerationStreamResponse {
|
|||||||
pub generated_text: String,
|
pub generated_text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Removes from the start of `response` any text that duplicates the end of `front`.
///
/// Finds the longest prefix of `response` that is also a suffix of `front` and
/// returns `response` with that prefix removed. When there is no overlap the
/// response is returned unchanged; when the whole response is a suffix of
/// `front` the result is empty.
///
/// Only char boundaries are considered as candidate split points, so responses
/// containing multi-byte UTF-8 characters never cause a slicing panic.
fn post_process_start(mut response: String, front: &str) -> String {
    // A prefix longer than `front` can never be a suffix of it, so cap the search there.
    let longest_candidate = response.len().min(front.len());
    let overlap = (1..=longest_candidate)
        .rev()
        // Slicing a `str` at a byte offset inside a multi-byte character panics.
        .filter(|&end| response.is_char_boundary(end))
        .find(|&end| front.ends_with(&response[..end]))
        .unwrap_or(0);
    // Trim in place; a zero-length range leaves the response untouched.
    response.replace_range(..overlap, "");
    response
}
|
||||||
|
|
||||||
|
/// Removes from the end of `response` any text that duplicates the start of `back`.
///
/// Finds the longest suffix of `response` that is also a prefix of `back` and
/// returns `response` with that suffix removed. When there is no overlap the
/// response is returned unchanged; when the whole response is a prefix of
/// `back` the result is empty, mirroring how `post_process_start` treats a
/// fully duplicated response.
///
/// Only char boundaries are considered as candidate split points, so responses
/// containing multi-byte UTF-8 characters never cause a slicing panic.
fn post_process_end(mut response: String, back: &str) -> String {
    // A suffix longer than `back` can never be a prefix of it, so start the search there.
    let first_candidate = response.len().saturating_sub(back.len());
    let keep = (first_candidate..response.len())
        // Slicing a `str` at a byte offset inside a multi-byte character panics.
        .filter(|&start| response.is_char_boundary(start))
        // Ascending order means the first hit is the longest overlapping suffix.
        .find(|&start| back.starts_with(&response[start..]))
        .unwrap_or(response.len());
    response.truncate(keep);
    response
}
|
||||||
|
|
||||||
|
// Some basic post processing that will clean up duplicate characters at the front and back
|
||||||
|
fn post_process_response(
|
||||||
|
response: String,
|
||||||
|
prompt: &Prompt,
|
||||||
|
config: &config::PostProcess,
|
||||||
|
) -> String {
|
||||||
|
match prompt {
|
||||||
|
Prompt::ContextAndCode(context_and_code) => {
|
||||||
|
if context_and_code.code.contains("<CURSOR>") {
|
||||||
|
let mut split = context_and_code.code.split("<CURSOR>");
|
||||||
|
let response = if config.remove_duplicate_start {
|
||||||
|
post_process_start(response, split.next().unwrap())
|
||||||
|
} else {
|
||||||
|
response
|
||||||
|
};
|
||||||
|
if config.remove_duplicate_end {
|
||||||
|
post_process_end(response, split.next().unwrap())
|
||||||
|
} else {
|
||||||
|
response
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if config.remove_duplicate_start {
|
||||||
|
post_process_start(response, &context_and_code.code)
|
||||||
|
} else {
|
||||||
|
response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Prompt::FIM(fim) => {
|
||||||
|
let response = if config.remove_duplicate_start {
|
||||||
|
post_process_start(response, &fim.prompt)
|
||||||
|
} else {
|
||||||
|
response
|
||||||
|
};
|
||||||
|
if config.remove_duplicate_end {
|
||||||
|
post_process_end(response, &fim.suffix)
|
||||||
|
} else {
|
||||||
|
response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn run(
|
pub fn run(
|
||||||
transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>>,
|
transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>>,
|
||||||
memory_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
memory_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||||
@@ -249,6 +327,7 @@ async fn do_completion(
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
// Build the prompt
|
||||||
let (tx, rx) = oneshot::channel();
|
let (tx, rx) = oneshot::channel();
|
||||||
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
|
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
|
||||||
request.params.text_document_position.clone(),
|
request.params.text_document_position.clone(),
|
||||||
@@ -258,13 +337,22 @@ async fn do_completion(
|
|||||||
)))?;
|
)))?;
|
||||||
let prompt = rx.await?;
|
let prompt = rx.await?;
|
||||||
|
|
||||||
|
// Get the filter text
|
||||||
let (tx, rx) = oneshot::channel();
|
let (tx, rx) = oneshot::channel();
|
||||||
memory_backend_tx.send(memory_worker::WorkerRequest::FilterText(
|
memory_backend_tx.send(memory_worker::WorkerRequest::FilterText(
|
||||||
FilterRequest::new(request.params.text_document_position.clone(), tx),
|
FilterRequest::new(request.params.text_document_position.clone(), tx),
|
||||||
))?;
|
))?;
|
||||||
let filter_text = rx.await?;
|
let filter_text = rx.await?;
|
||||||
|
|
||||||
let response = transformer_backend.do_completion(&prompt, params).await?;
|
// Get the response
|
||||||
|
let mut response = transformer_backend.do_completion(&prompt, params).await?;
|
||||||
|
eprintln!("\n\n\n\nGOT RESPONSE: {}\n\n\n\n", response.insert_text);
|
||||||
|
|
||||||
|
if let Some(post_process) = config.get_completions_post_process() {
|
||||||
|
response.insert_text = post_process_response(response.insert_text, &prompt, &post_process);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build and send the response
|
||||||
let completion_text_edit = TextEdit::new(
|
let completion_text_edit = TextEdit::new(
|
||||||
Range::new(
|
Range::new(
|
||||||
Position::new(
|
Position::new(
|
||||||
@@ -314,7 +402,13 @@ async fn do_generate(
|
|||||||
)))?;
|
)))?;
|
||||||
let prompt = rx.await?;
|
let prompt = rx.await?;
|
||||||
|
|
||||||
let response = transformer_backend.do_generate(&prompt, params).await?;
|
let mut response = transformer_backend.do_generate(&prompt, params).await?;
|
||||||
|
response.generated_text = post_process_response(
|
||||||
|
response.generated_text,
|
||||||
|
&prompt,
|
||||||
|
&request.params.post_process,
|
||||||
|
);
|
||||||
|
|
||||||
let result = GenerateResult {
|
let result = GenerateResult {
|
||||||
generated_text: response.generated_text,
|
generated_text: response.generated_text,
|
||||||
};
|
};
|
||||||
@@ -325,3 +419,67 @@ async fn do_generate(
|
|||||||
error: None,
|
error: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_backends::{ContextAndCodePrompt, FIMPrompt};

    /// Builds a FIM prompt from the text before (`prompt`) and after (`suffix`) the cursor.
    fn fim_prompt(prompt: &str, suffix: &str) -> Prompt {
        Prompt::FIM(FIMPrompt {
            prompt: prompt.to_string(),
            suffix: suffix.to_string(),
        })
    }

    /// Builds a context-and-code prompt with an empty context.
    fn context_and_code_prompt(code: &str) -> Prompt {
        Prompt::ContextAndCode(ContextAndCodePrompt {
            context: "".to_string(),
            code: code.to_string(),
        })
    }

    #[test]
    fn test_post_process_fim() {
        let config = config::PostProcess::default();

        // (prompt, suffix, raw response, expected cleaned response)
        let cases = [
            ("test 1234 ", "ttabc", "4 zz tta", "zz "),
            ("test", "test", "zzzz", "zzzz"),
        ];
        for (prompt, suffix, response, expected) in cases {
            let prompt = fim_prompt(prompt, suffix);
            let new_response = post_process_response(response.to_string(), &prompt, &config);
            assert_eq!(new_response, expected);
        }
    }

    #[test]
    fn test_post_process_context_and_code() {
        let config = config::PostProcess::default();

        // (code, raw response, expected cleaned response)
        let cases = [
            ("tt ", "tt abc", "abc"),
            ("ff", "zz", "zz"),
            ("tt <CURSOR> tt", "tt abc tt", "abc"),
            ("d<CURSOR>d", "zz", "zz"),
        ];
        for (code, response, expected) in cases {
            let prompt = context_and_code_prompt(code);
            let new_response = post_process_response(response.to_string(), &prompt, &config);
            assert_eq!(new_response, expected);
        }
    }
}
|
||||||
|
|||||||
Reference in New Issue
Block a user