mirror of https://github.com/SilasMarvin/lsp-ai.git

Commit: Compiling
@@ -3,8 +3,6 @@ use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use std::collections::HashMap;

-use crate::memory_backends::PromptForType;
-
 const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
 const DEFAULT_OPENAI_MAX_CONTEXT_LENGTH: usize = 2048;

@@ -187,7 +185,7 @@ pub struct Completion {
     // pub chat: Option<Vec<ChatMessage>>,
     // pub chat_template: Option<String>,
     // pub chat_format: Option<String>,
-    kwargs: HashMap<String, Value>,
+    pub kwargs: HashMap<String, Value>,
 }

 impl Default for Completion {
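A note on the `kwargs` field made `pub` above: a minimal sketch of how such a catch-all map is often populated with serde's `#[serde(flatten)]`, so unrecognized config keys survive as loose JSON. This mechanism is an assumption for illustration; the hunk itself does not show how `kwargs` is filled, and the struct here is illustrative.

use std::collections::HashMap;

use serde::Deserialize;
use serde_json::{json, Value};

#[derive(Debug, Deserialize)]
struct Completion {
    model: String,
    // Assumed for illustration: unrecognized keys fall through into kwargs.
    #[serde(flatten)]
    kwargs: HashMap<String, Value>,
}

fn main() -> Result<(), serde_json::Error> {
    let c: Completion = serde_json::from_value(json!({
        "model": "some-model",
        "max_tokens": 64
    }))?;
    // "max_tokens" was not a declared field, so it landed in kwargs.
    println!("{:?}", c.kwargs.get("max_tokens"));
    Ok(())
}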
@@ -262,26 +260,6 @@ impl Config {
             ValidModel::Anthropic(anthropic) => anthropic.max_requests_per_second,
         }
     }
-
-    // pub fn get_completion_max_context_length(&self) -> anyhow::Result<usize> {
-    //     Ok(self.config.completion.max_context_length)
-    // }
-
-    // pub fn get_fim(&self) -> Option<&FIM> {
-    //     match &self.config.transformer {
-    //         ValidModel::LLaMACPP(llama_cpp) => llama_cpp.fim.as_ref(),
-    //         ValidModel::OpenAI(openai) => openai.fim.as_ref(),
-    //         ValidModel::Anthropic(_) => None,
-    //     }
-    // }
-
-    // pub fn get_chat(&self) -> Option<&Chat> {
-    //     match &self.config.transformer {
-    //         ValidModel::LLaMACPP(llama_cpp) => llama_cpp.chat.as_ref(),
-    //         ValidModel::OpenAI(openai) => openai.chat.as_ref(),
-    //         ValidModel::Anthropic(anthropic) => Some(&anthropic.chat),
-    //     }
-    // }
 }

 #[cfg(test)]
@@ -3,6 +3,7 @@ use indexmap::IndexSet;
 use lsp_types::TextDocumentPositionParams;
 use parking_lot::Mutex;
 use ropey::Rope;
+use serde_json::Value;
 use std::collections::HashMap;
 use tracing::instrument;

@@ -11,7 +12,7 @@ use crate::{
     utils::tokens_to_estimated_characters,
 };

-use super::{MemoryBackend, Prompt, PromptForType};
+use super::{MemoryBackend, MemoryRunParams, Prompt};

 pub struct FileStore {
     _crawl: bool,
@@ -104,33 +105,20 @@ impl FileStore {
     pub fn build_code(
         &self,
         position: &TextDocumentPositionParams,
-        prompt_for_type: PromptForType,
-        max_context_length: usize,
+        params: MemoryRunParams,
     ) -> anyhow::Result<String> {
-        let (mut rope, cursor_index) = self.get_rope_for_position(position, max_context_length)?;
+        let (mut rope, cursor_index) =
+            self.get_rope_for_position(position, params.max_context_length)?;

-        let is_chat_enabled = match prompt_for_type {
-            PromptForType::Completion => self
-                .config
-                .get_chat()
-                .map(|c| c.completion.is_some())
-                .unwrap_or(false),
-            PromptForType::Generate => self
-                .config
-                .get_chat()
-                .map(|c| c.generation.is_some())
-                .unwrap_or(false),
-        };
-
-        Ok(match (is_chat_enabled, self.config.get_fim()) {
+        Ok(match (params.chat.is_some(), params.fim) {
             r @ (true, _) | r @ (false, Some(_)) if rope.len_chars() != cursor_index => {
-                let max_length = tokens_to_estimated_characters(max_context_length);
+                let max_length = tokens_to_estimated_characters(params.max_context_length);
                 let start = cursor_index.saturating_sub(max_length / 2);
                 let end = rope
                     .len_chars()
                     .min(cursor_index + (max_length - (cursor_index - start)));

-                if is_chat_enabled {
+                if r.0 {
                     rope.insert(cursor_index, "<CURSOR>");
                     let rope_slice = rope
                         .get_slice(start..end + "<CURSOR>".chars().count())
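The `r @ (true, _) | r @ (false, Some(_))` arm above is easy to misread. A tiny standalone sketch of the binding, with made-up values: `r` captures the whole matched tuple, so `r.0` later recovers the chat flag without re-matching.

fn main() {
    let chat_enabled = false;
    let fim: Option<&str> = Some("fim-config");

    match (chat_enabled, fim) {
        // Both alternatives bind the full tuple to `r`.
        r @ (true, _) | r @ (false, Some(_)) => {
            if r.0 {
                println!("chat path");
            } else {
                println!("fim path");
            }
        }
        _ => println!("plain completion path"),
    }
}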
@@ -154,8 +142,8 @@ impl FileStore {
                 }
             }
             _ => {
-                let start =
-                    cursor_index.saturating_sub(tokens_to_estimated_characters(max_context_length));
+                let start = cursor_index
+                    .saturating_sub(tokens_to_estimated_characters(params.max_context_length));
                 let rope_slice = rope
                     .get_slice(start..cursor_index)
                     .context("Error getting rope slice")?;
@@ -190,10 +178,10 @@ impl MemoryBackend for FileStore {
     async fn build_prompt(
         &self,
         position: &TextDocumentPositionParams,
-        max_context_length: usize,
-        prompt_for_type: PromptForType,
+        params: Value,
     ) -> anyhow::Result<Prompt> {
-        let code = self.build_code(position, prompt_for_type, max_context_length)?;
+        let params: MemoryRunParams = serde_json::from_value(params)?;
+        let code = self.build_code(position, params)?;
         Ok(Prompt::new("".to_string(), code))
     }

@@ -2,12 +2,26 @@ use lsp_types::{
     DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
     TextDocumentPositionParams,
 };
+use serde::Deserialize;
+use serde_json::Value;

-use crate::config::{Config, ValidMemoryBackend};
+use crate::config::{ChatMessage, Config, ValidMemoryBackend, FIM};

 pub mod file_store;
 mod postgresml;

+const fn max_context_length_default() -> usize {
+    1024
+}
+
+#[derive(Clone, Deserialize)]
+struct MemoryRunParams {
+    pub fim: Option<FIM>,
+    pub chat: Option<Vec<ChatMessage>>,
+    #[serde(default = "max_context_length_default")]
+    pub max_context_length: usize,
+}
+
 #[derive(Debug)]
 pub struct Prompt {
     pub context: String,
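The new `MemoryRunParams` leans on serde field defaults so callers can send sparse JSON. A minimal standalone sketch, trimmed to the one defaulted field (not code from this commit), of how `serde_json::from_value` behaves with and without it:

use serde::Deserialize;
use serde_json::json;

const fn max_context_length_default() -> usize {
    1024
}

#[derive(Debug, Deserialize)]
struct MemoryRunParams {
    #[serde(default = "max_context_length_default")]
    max_context_length: usize,
}

fn main() -> Result<(), serde_json::Error> {
    // No field supplied: the serde default fills it in.
    let p: MemoryRunParams = serde_json::from_value(json!({}))?;
    assert_eq!(p.max_context_length, 1024);

    // Field supplied: the caller's value wins.
    let p: MemoryRunParams = serde_json::from_value(json!({ "max_context_length": 2048 }))?;
    assert_eq!(p.max_context_length, 2048);
    Ok(())
}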
@@ -20,12 +34,6 @@ impl Prompt {
     }
 }

-#[derive(Debug)]
-pub enum PromptForType {
-    Completion,
-    Generate,
-}
-
 #[async_trait::async_trait]
 pub trait MemoryBackend {
     async fn init(&self) -> anyhow::Result<()> {
@@ -40,8 +48,7 @@ pub trait MemoryBackend {
     async fn build_prompt(
         &self,
         position: &TextDocumentPositionParams,
-        max_context_length: usize,
-        prompt_for_type: PromptForType,
+        params: Value,
     ) -> anyhow::Result<Prompt>;
     async fn get_filter_text(
         &self,
@@ -6,7 +6,7 @@ use std::{
 use anyhow::Context;
 use lsp_types::TextDocumentPositionParams;
 use pgml::{Collection, Pipeline};
-use serde_json::json;
+use serde_json::{json, Value};
 use tokio::time;
 use tracing::instrument;

@@ -15,7 +15,7 @@ use crate::{
     utils::tokens_to_estimated_characters,
 };

-use super::{file_store::FileStore, MemoryBackend, Prompt, PromptForType};
+use super::{file_store::FileStore, MemoryBackend, MemoryRunParams, Prompt};

 pub struct PostgresML {
     configuration: Config,
@@ -129,9 +129,9 @@ impl MemoryBackend for PostgresML {
     async fn build_prompt(
         &self,
         position: &TextDocumentPositionParams,
-        max_context_length: usize,
-        prompt_for_type: PromptForType,
+        params: Value,
     ) -> anyhow::Result<Prompt> {
+        let params: MemoryRunParams = serde_json::from_value(params)?;
         let query = self
             .file_store
             .get_characters_around_position(position, 512)?;
@@ -162,8 +162,10 @@ impl MemoryBackend for PostgresML {
             })
             .collect::<anyhow::Result<Vec<String>>>()?
             .join("\n\n");
-        let code = self.file_store.build_code(position, prompt_for_type, 512)?;
-        let max_characters = tokens_to_estimated_characters(max_context_length);
+        let mut file_store_params = params.clone();
+        file_store_params.max_context_length = 512;
+        let code = self.file_store.build_code(position, file_store_params)?;
+        let max_characters = tokens_to_estimated_characters(params.max_context_length);
         let context: String = context
             .chars()
             .take(max_characters - code.chars().count())
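The character-budget arithmetic above gives the code snippet priority and truncates retrieved context to whatever remains. A standalone sketch with made-up values, mirroring the commit's plain subtraction (so it assumes the code fits inside the budget):

fn main() {
    let max_characters = 40;
    let code = "fn main() {}";
    let context = "retrieved snippets joined with blank lines...";

    // The code is kept whole; context gets only the leftover budget.
    let remaining = max_characters - code.chars().count();
    let truncated: String = context.chars().take(remaining).collect();
    println!("context fits in {remaining} chars: {truncated:?}");
}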
@@ -4,29 +4,27 @@ use lsp_types::{
     DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
     TextDocumentPositionParams,
 };
+use serde_json::Value;
 use tracing::error;

-use crate::memory_backends::{MemoryBackend, Prompt, PromptForType};
+use crate::memory_backends::{MemoryBackend, Prompt};

 #[derive(Debug)]
 pub struct PromptRequest {
     position: TextDocumentPositionParams,
-    max_context_length: usize,
-    prompt_for_type: PromptForType,
+    params: Value,
     tx: tokio::sync::oneshot::Sender<Prompt>,
 }

 impl PromptRequest {
     pub fn new(
         position: TextDocumentPositionParams,
-        max_context_length: usize,
-        prompt_for_type: PromptForType,
+        params: Value,
         tx: tokio::sync::oneshot::Sender<Prompt>,
     ) -> Self {
         Self {
             position,
-            max_context_length,
-            prompt_for_type,
+            params,
             tx,
         }
     }
@@ -69,11 +67,7 @@ async fn do_task(
         }
         WorkerRequest::Prompt(params) => {
             let prompt = memory_backend
-                .build_prompt(
-                    &params.position,
-                    params.max_context_length,
-                    params.prompt_for_type,
-                )
+                .build_prompt(&params.position, params.params)
                 .await?;
             params
                 .tx
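The worker request above now carries opaque JSON params plus a oneshot sender for the reply. A hedged sketch of that pattern with illustrative types (not the crate's actual ones): the transport stays untyped, and the worker decides how to interpret the params.

use serde_json::{json, Value};
use tokio::sync::oneshot;

struct Request {
    params: Value,
    tx: oneshot::Sender<String>,
}

async fn worker(req: Request) {
    // The worker interprets `params`; the channel itself carries untyped JSON.
    let reply = format!("got params: {}", req.params);
    let _ = req.tx.send(reply);
}

#[tokio::main]
async fn main() {
    let (tx, rx) = oneshot::channel();
    let req = Request { params: json!({ "max_context_length": 512 }), tx };
    tokio::spawn(worker(req));
    println!("{}", rx.await.unwrap());
}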
@@ -13,7 +13,7 @@ use crate::{
     utils::format_chat_messages,
 };

-use super::{RunParams, TransformerBackend};
+use super::TransformerBackend;

 const fn max_tokens_default() -> usize {
     64
@@ -27,10 +27,6 @@ const fn temperature_default() -> f32 {
     0.1
 }

-const fn max_context_length_default() -> usize {
-    1024
-}
-
 #[derive(Debug, Deserialize)]
 pub struct AnthropicRunParams {
     chat: Vec<ChatMessage>,
@@ -40,8 +36,6 @@ pub struct AnthropicRunParams {
     pub top_p: f32,
     #[serde(default = "temperature_default")]
     pub temperature: f32,
-    #[serde(default = "max_context_length_default")]
-    max_context_length: usize,
 }

 pub struct Anthropic {
@@ -133,9 +127,10 @@ impl TransformerBackend for Anthropic {
     async fn do_completion(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoCompletionResponse> {
-        let params: AnthropicRunParams = params.try_into()?;
+        // let params: AnthropicRunParams = params.try_into()?;
+        let params: AnthropicRunParams = serde_json::from_value(params)?;
         let insert_text = self.do_get_chat(prompt, params).await?;
         Ok(DoCompletionResponse { insert_text })
     }
@@ -144,9 +139,9 @@ impl TransformerBackend for Anthropic {
     async fn do_generate(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationResponse> {
-        let params: AnthropicRunParams = params.try_into()?;
+        let params: AnthropicRunParams = serde_json::from_value(params)?;
         let generated_text = self.do_get_chat(prompt, params).await?;
         Ok(DoGenerationResponse { generated_text })
     }
@@ -155,7 +150,7 @@ impl TransformerBackend for Anthropic {
     async fn do_generate_stream(
         &self,
         request: &GenerationStreamRequest,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationStreamResponse> {
         unimplemented!()
     }

@@ -1,6 +1,7 @@
 use anyhow::Context;
 use hf_hub::api::sync::ApiBuilder;
 use serde::Deserialize;
+use serde_json::Value;
 use tracing::instrument;

 use crate::{
@@ -17,7 +18,7 @@ use crate::{
 mod model;
 use model::Model;

-use super::{RunParams, TransformerBackend};
+use super::TransformerBackend;

 const fn max_new_tokens_default() -> usize {
     32
@@ -86,9 +87,9 @@ impl TransformerBackend for LLaMACPP {
     async fn do_completion(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoCompletionResponse> {
-        let params: LLaMACPPRunParams = params.try_into()?;
+        let params: LLaMACPPRunParams = serde_json::from_value(params)?;
         let prompt = self.get_prompt_string(prompt, &params)?;
         self.model
             .complete(&prompt, params)
@@ -99,9 +100,9 @@ impl TransformerBackend for LLaMACPP {
     async fn do_generate(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationResponse> {
-        let params: LLaMACPPRunParams = params.try_into()?;
+        let params: LLaMACPPRunParams = serde_json::from_value(params)?;
         let prompt = self.get_prompt_string(prompt, &params)?;
         self.model
             .complete(&prompt, params)
@@ -112,7 +113,7 @@ impl TransformerBackend for LLaMACPP {
     async fn do_generate_stream(
         &self,
         _request: &GenerationStreamRequest,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationStreamResponse> {
         unimplemented!()
     }

@@ -1,3 +1,5 @@
+use serde_json::Value;
+
 use crate::{
     config::{self, ValidModel},
     memory_backends::Prompt,
@@ -13,56 +15,47 @@ mod anthropic;
 mod llama_cpp;
 mod openai;

-#[derive(Debug)]
-pub enum RunParams {
-    LLaMACPP(llama_cpp::LLaMACPPRunParams),
-    Anthropic(anthropic::AnthropicRunParams),
-    OpenAI(openai::OpenAIRunParams),
-}
+// impl RunParams {
+//     pub fn from_completion(completion: &Completion) -> Self {
+//         todo!()
+//     }
+// }

-impl RunParams {
-    pub fn from_completion(completion: &Completion) -> Self {
-        todo!()
-    }
-}
+// macro_rules! impl_runparams_try_into {
+//     ( $f:ident, $t:ident ) => {
+//         impl TryInto<$f> for RunParams {
+//             type Error = anyhow::Error;

-macro_rules! impl_runparams_try_into {
-    ( $f:ident, $t:ident ) => {
-        impl TryInto<$f> for RunParams {
-            type Error = anyhow::Error;
+//             fn try_into(self) -> Result<$f, Self::Error> {
+//                 match self {
+//                     Self::$t(a) => Ok(a),
+//                     _ => anyhow::bail!("Cannot convert RunParams into {}", stringify!($f)),
+//                 }
+//             }
+//         }
+//     };
+// }

-            fn try_into(self) -> Result<$f, Self::Error> {
-                match self {
-                    Self::$t(a) => Ok(a),
-                    _ => anyhow::bail!("Cannot convert RunParams into {}", stringify!($f)),
-                }
-            }
-        }
-    };
-}
-
-impl_runparams_try_into!(AnthropicRunParams, Anthropic);
-impl_runparams_try_into!(LLaMACPPRunParams, LLaMACPP);
-impl_runparams_try_into!(OpenAIRunParams, OpenAI);
+// impl_runparams_try_into!(AnthropicRunParams, Anthropic);
+// impl_runparams_try_into!(LLaMACPPRunParams, LLaMACPP);
+// impl_runparams_try_into!(OpenAIRunParams, OpenAI);

 #[async_trait::async_trait]
 pub trait TransformerBackend {
-    type Test = LLaMACPPRunParams;
-
     async fn do_completion(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoCompletionResponse>;
     async fn do_generate(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationResponse>;
     async fn do_generate_stream(
         &self,
         request: &GenerationStreamRequest,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationStreamResponse>;
 }

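This hunk is the heart of the refactor: the `RunParams` enum and its `TryInto` macro give way to passing `serde_json::Value` through the trait and letting each backend deserialize its own params struct. A small sketch of the trade with illustrative backend types (not the crate's): one untyped payload serves several backends, and each ignores fields it does not declare, since serde skips unknown fields by default.

use serde::Deserialize;
use serde_json::{json, Value};

#[derive(Debug, Deserialize)]
struct BackendARunParams {
    max_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct BackendBRunParams {
    temperature: f32,
}

fn main() -> Result<(), serde_json::Error> {
    // One untyped payload serves both backends.
    let params: Value = json!({ "max_tokens": 64, "temperature": 0.1 });

    // Each backend pulls out only the fields it declares.
    let a: BackendARunParams = serde_json::from_value(params.clone())?;
    let b: BackendBRunParams = serde_json::from_value(params)?;
    println!("{a:?} {b:?}");
    Ok(())
}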
@@ -13,7 +13,7 @@ use crate::{
     utils::{format_chat_messages, format_context_code},
 };

-use super::{RunParams, TransformerBackend};
+use super::TransformerBackend;

 const fn max_tokens_default() -> usize {
     64
@@ -35,10 +35,6 @@ const fn temperature_default() -> f32 {
     0.1
 }

-const fn max_context_length_default() -> usize {
-    1024
-}
-
 #[derive(Debug, Deserialize)]
 pub struct OpenAIRunParams {
     pub fim: Option<FIM>,
@@ -53,8 +49,6 @@ pub struct OpenAIRunParams {
     pub frequency_penalty: f32,
     #[serde(default = "temperature_default")]
     pub temperature: f32,
-    #[serde(default = "max_context_length_default")]
-    max_context_length: usize,
 }

 pub struct OpenAI {
@@ -202,9 +196,10 @@ impl TransformerBackend for OpenAI {
     async fn do_completion(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoCompletionResponse> {
-        let params: OpenAIRunParams = params.try_into()?;
+        // let params: OpenAIRunParams = params.try_into()?;
+        let params: OpenAIRunParams = serde_json::from_value(params)?;
         let insert_text = self.do_chat_completion(prompt, params).await?;
         Ok(DoCompletionResponse { insert_text })
     }
@@ -214,9 +209,9 @@ impl TransformerBackend for OpenAI {
     async fn do_generate(
         &self,
         prompt: &Prompt,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationResponse> {
-        let params: OpenAIRunParams = params.try_into()?;
+        let params: OpenAIRunParams = serde_json::from_value(params)?;
         let generated_text = self.do_chat_completion(prompt, params).await?;
         Ok(DoGenerationResponse { generated_text })
     }
@@ -225,7 +220,7 @@ impl TransformerBackend for OpenAI {
     async fn do_generate_stream(
         &self,
         request: &GenerationStreamRequest,
-        params: RunParams,
+        params: Value,
     ) -> anyhow::Result<DoGenerationStreamResponse> {
         unimplemented!()
     }

@@ -14,7 +14,6 @@ use tracing::{error, instrument};
 use crate::config::Config;
 use crate::custom_requests::generation::{GenerateResult, GenerationParams};
 use crate::custom_requests::generation_stream::GenerationStreamParams;
-use crate::memory_backends::PromptForType;
 use crate::memory_worker::{self, FilterRequest, PromptRequest};
 use crate::transformer_backends::TransformerBackend;
 use crate::utils::ToResponseError;
@@ -219,15 +218,13 @@ async fn do_completion(
     // TODO: Fix this
     // we need to be subtracting the completion / generation tokens from max_context_length
     // not sure if we should be doing that for the chat maybe leave a note here for that?
-    // let max_context_length = config.get_completion_max_context_length()?;

-    let params:
+    let params = serde_json::to_value(config.config.completion.kwargs.clone()).unwrap();

     let (tx, rx) = oneshot::channel();
     memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
         request.params.text_document_position.clone(),
-        max_context_length,
-        PromptForType::Completion,
+        params.clone(),
         tx,
     )))?;
     let prompt = rx.await?;
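`serde_json::to_value` on the completion `kwargs` map cannot fail (string keys, `Value` values), which is why the `unwrap` above is safe. A hedged sketch of the round trip with an illustrative params struct: config-level kwargs become one `Value`, and each consumer deserializes what it needs.

use std::collections::HashMap;

use serde::Deserialize;
use serde_json::{json, Value};

#[derive(Debug, Deserialize)]
struct RunParams {
    max_tokens: usize,
}

fn main() -> Result<(), serde_json::Error> {
    let mut kwargs: HashMap<String, Value> = HashMap::new();
    kwargs.insert("max_tokens".to_string(), json!(64));

    // to_value on a HashMap<String, Value> is infallible, hence unwrap;
    // from_value then recovers a typed view downstream.
    let params = serde_json::to_value(kwargs).unwrap();
    let typed: RunParams = serde_json::from_value(params.clone())?;
    println!("{typed:?}");
    Ok(())
}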
@@ -238,7 +235,7 @@ async fn do_completion(
     ))?;
     let filter_text = rx.await?;

-    let response = transformer_backend.do_completion(&prompt).await?;
+    let response = transformer_backend.do_completion(&prompt, params).await?;
     let completion_text_edit = TextEdit::new(
         Range::new(
             Position::new(