mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2025-12-18 06:54:28 +01:00
Merge pull request #23 from SilasMarvin/silas-add-initial-post-processing
Added initial post processing to remove duplicate start and end characters
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lsp-ai"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
description = "LSP-AI is an open-source language server that serves as a backend for AI-powered functionality, designed to assist and empower software engineers, not replace them."
|
||||
|
||||
@@ -9,6 +9,21 @@ const fn max_requests_per_second_default() -> f32 {
|
||||
1.
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct PostProcess {
|
||||
pub remove_duplicate_start: bool,
|
||||
pub remove_duplicate_end: bool,
|
||||
}
|
||||
|
||||
impl Default for PostProcess {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
remove_duplicate_start: true,
|
||||
remove_duplicate_end: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub enum ValidMemoryBackend {
|
||||
#[serde(rename = "file_store")]
|
||||
@@ -177,10 +192,12 @@ pub struct Anthropic {
|
||||
pub struct Completion {
|
||||
// The model key to use
|
||||
pub model: String,
|
||||
|
||||
// Args are deserialized by the backend using them
|
||||
#[serde(default)]
|
||||
pub parameters: Kwargs,
|
||||
// Parameters for post processing
|
||||
#[serde(default)]
|
||||
pub post_process: PostProcess,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
@@ -230,6 +247,10 @@ impl Config {
|
||||
self.config.completion.is_some()
|
||||
}
|
||||
|
||||
/// Returns the post-processing settings of the completion section, or
/// `None` when no completion section is configured.
pub fn get_completions_post_process(&self) -> Option<&PostProcess> {
    let completion = self.config.completion.as_ref()?;
    Some(&completion.post_process)
}
|
||||
|
||||
pub fn get_completion_transformer_max_requests_per_second(&self) -> anyhow::Result<f32> {
|
||||
match &self
|
||||
.config
|
||||
@@ -335,6 +356,10 @@ mod test {
|
||||
"options": {
|
||||
"num_predict": 32
|
||||
}
|
||||
},
|
||||
"post_process": {
|
||||
"remove_duplicate_start": true,
|
||||
"remove_duplicate_end": true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,20 +2,27 @@ use lsp_types::TextDocumentPositionParams;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::config;
|
||||
|
||||
pub enum Generation {}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationParams {
|
||||
// This field was "mixed-in" from TextDocumentPositionParams
|
||||
#[serde(flatten)]
|
||||
pub text_document_position: TextDocumentPositionParams,
|
||||
// The model key to use
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
// Args are deserialized by the backend using them
|
||||
pub parameters: Value,
|
||||
// Parameters for post processing
|
||||
#[serde(default)]
|
||||
pub post_process: config::PostProcess,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateResult {
|
||||
pub generated_text: String,
|
||||
|
||||
@@ -11,9 +11,10 @@ use std::time::{Duration, SystemTime};
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{error, instrument};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::config::{self, Config};
|
||||
use crate::custom_requests::generation::{GenerateResult, GenerationParams};
|
||||
use crate::custom_requests::generation_stream::GenerationStreamParams;
|
||||
use crate::memory_backends::Prompt;
|
||||
use crate::memory_worker::{self, FilterRequest, PromptRequest};
|
||||
use crate::transformer_backends::TransformerBackend;
|
||||
use crate::utils::ToResponseError;
|
||||
@@ -85,6 +86,83 @@ pub struct DoGenerationStreamResponse {
|
||||
pub generated_text: String,
|
||||
}
|
||||
|
||||
/// Removes from the start of `response` the longest prefix that is also a
/// suffix of `front` (the text that precedes the response).
///
/// Returns `response` unchanged when there is no such overlap. When the whole
/// response is a suffix of `front`, the result is the empty string.
///
/// Only offsets that fall on UTF-8 character boundaries are considered, so
/// responses containing multi-byte characters never cause a slicing panic.
fn post_process_start(response: String, front: &str) -> String {
    // Try the longest candidate prefix first and shrink until one matches.
    let overlap_end = (1..=response.len())
        .rev()
        .filter(|&end| response.is_char_boundary(end))
        .find(|&end| front.ends_with(&response[..end]));

    match overlap_end {
        Some(end) => response[end..].to_owned(),
        None => response,
    }
}
|
||||
|
||||
/// Removes from the end of `response` the longest proper suffix that is also
/// a prefix of `back` (the text that follows the response).
///
/// Returns `response` unchanged when there is no such overlap.
///
/// Only offsets that fall on UTF-8 character boundaries are considered, so
/// responses containing multi-byte characters never cause a slicing panic.
fn post_process_end(response: String, back: &str) -> String {
    // NOTE(review): when the *entire* response is a prefix of `back`, it is
    // returned untouched rather than emptied. This keeps the existing
    // behaviour; confirm whether that asymmetry with the start pass is intended.
    if back.starts_with(response.as_str()) {
        return response;
    }

    // Scan from the longest candidate suffix to the shortest.
    let overlap_start = (1..response.len())
        .filter(|&start| response.is_char_boundary(start))
        .find(|&start| back.starts_with(&response[start..]));

    match overlap_start {
        Some(start) => response[..start].to_owned(),
        None => response,
    }
}
|
||||
|
||||
// Some basic post processing that will clean up duplicate characters at the front and back
|
||||
fn post_process_response(
|
||||
response: String,
|
||||
prompt: &Prompt,
|
||||
config: &config::PostProcess,
|
||||
) -> String {
|
||||
match prompt {
|
||||
Prompt::ContextAndCode(context_and_code) => {
|
||||
if context_and_code.code.contains("<CURSOR>") {
|
||||
let mut split = context_and_code.code.split("<CURSOR>");
|
||||
let response = if config.remove_duplicate_start {
|
||||
post_process_start(response, split.next().unwrap())
|
||||
} else {
|
||||
response
|
||||
};
|
||||
if config.remove_duplicate_end {
|
||||
post_process_end(response, split.next().unwrap())
|
||||
} else {
|
||||
response
|
||||
}
|
||||
} else {
|
||||
if config.remove_duplicate_start {
|
||||
post_process_start(response, &context_and_code.code)
|
||||
} else {
|
||||
response
|
||||
}
|
||||
}
|
||||
}
|
||||
Prompt::FIM(fim) => {
|
||||
let response = if config.remove_duplicate_start {
|
||||
post_process_start(response, &fim.prompt)
|
||||
} else {
|
||||
response
|
||||
};
|
||||
if config.remove_duplicate_end {
|
||||
post_process_end(response, &fim.suffix)
|
||||
} else {
|
||||
response
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>>,
|
||||
memory_tx: std::sync::mpsc::Sender<memory_worker::WorkerRequest>,
|
||||
@@ -249,6 +327,7 @@ async fn do_completion(
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Build the prompt
|
||||
let (tx, rx) = oneshot::channel();
|
||||
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
|
||||
request.params.text_document_position.clone(),
|
||||
@@ -258,13 +337,22 @@ async fn do_completion(
|
||||
)))?;
|
||||
let prompt = rx.await?;
|
||||
|
||||
// Get the filter text
|
||||
let (tx, rx) = oneshot::channel();
|
||||
memory_backend_tx.send(memory_worker::WorkerRequest::FilterText(
|
||||
FilterRequest::new(request.params.text_document_position.clone(), tx),
|
||||
))?;
|
||||
let filter_text = rx.await?;
|
||||
|
||||
let response = transformer_backend.do_completion(&prompt, params).await?;
|
||||
// Get the response
|
||||
let mut response = transformer_backend.do_completion(&prompt, params).await?;
|
||||
eprintln!("\n\n\n\nGOT RESPONSE: {}\n\n\n\n", response.insert_text);
|
||||
|
||||
if let Some(post_process) = config.get_completions_post_process() {
|
||||
response.insert_text = post_process_response(response.insert_text, &prompt, &post_process);
|
||||
}
|
||||
|
||||
// Build and send the response
|
||||
let completion_text_edit = TextEdit::new(
|
||||
Range::new(
|
||||
Position::new(
|
||||
@@ -314,7 +402,13 @@ async fn do_generate(
|
||||
)))?;
|
||||
let prompt = rx.await?;
|
||||
|
||||
let response = transformer_backend.do_generate(&prompt, params).await?;
|
||||
let mut response = transformer_backend.do_generate(&prompt, params).await?;
|
||||
response.generated_text = post_process_response(
|
||||
response.generated_text,
|
||||
&prompt,
|
||||
&request.params.post_process,
|
||||
);
|
||||
|
||||
let result = GenerateResult {
|
||||
generated_text: response.generated_text,
|
||||
};
|
||||
@@ -325,3 +419,67 @@ async fn do_generate(
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_backends::{ContextAndCodePrompt, FIMPrompt};

    /// Builds a fill-in-the-middle prompt from the text before and after the gap.
    fn fim_prompt(prompt: &str, suffix: &str) -> Prompt {
        Prompt::FIM(FIMPrompt {
            prompt: prompt.to_string(),
            suffix: suffix.to_string(),
        })
    }

    /// Builds a context-and-code prompt with an empty context.
    fn code_prompt(code: &str) -> Prompt {
        Prompt::ContextAndCode(ContextAndCodePrompt {
            context: "".to_string(),
            code: code.to_string(),
        })
    }

    /// Runs the post processing with the default configuration.
    fn process(response: &str, prompt: &Prompt) -> String {
        let config = config::PostProcess::default();
        post_process_response(response.to_string(), prompt, &config)
    }

    #[test]
    fn test_post_process_fim() {
        // Overlap on both sides is stripped.
        let prompt = fim_prompt("test 1234 ", "ttabc");
        assert_eq!(process("4 zz tta", &prompt), "zz ");

        // No overlap leaves the response untouched.
        let prompt = fim_prompt("test", "test");
        assert_eq!(process("zzzz", &prompt), "zzzz");
    }

    #[test]
    fn test_post_process_context_and_code() {
        // No cursor marker: only the start is de-duplicated.
        let prompt = code_prompt("tt ");
        assert_eq!(process("tt abc", &prompt), "abc");

        let prompt = code_prompt("ff");
        assert_eq!(process("zz", &prompt), "zz");

        // With a cursor marker both sides are de-duplicated.
        let prompt = code_prompt("tt <CURSOR> tt");
        assert_eq!(process("tt abc tt", &prompt), "abc");

        let prompt = code_prompt("d<CURSOR>d");
        assert_eq!(process("zz", &prompt), "zz");
    }
}
|
||||
|
||||
Reference in New Issue
Block a user