Compiling

SilasMarvin
2024-05-12 18:40:51 -07:00
parent 0c04ab08b4
commit 31251e7235
10 changed files with 95 additions and 145 deletions

View File

@@ -3,8 +3,6 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use crate::memory_backends::PromptForType;
const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
const DEFAULT_OPENAI_MAX_CONTEXT_LENGTH: usize = 2048;
@@ -187,7 +185,7 @@ pub struct Completion {
// pub chat: Option<Vec<ChatMessage>>,
// pub chat_template: Option<String>,
// pub chat_format: Option<String>,
kwargs: HashMap<String, Value>,
pub kwargs: HashMap<String, Value>,
}
impl Default for Completion {
@@ -262,26 +260,6 @@ impl Config {
ValidModel::Anthropic(anthropic) => anthropic.max_requests_per_second,
}
}
// pub fn get_completion_max_context_length(&self) -> anyhow::Result<usize> {
// Ok(self.config.completion.max_context_length)
// }
// pub fn get_fim(&self) -> Option<&FIM> {
// match &self.config.transformer {
// ValidModel::LLaMACPP(llama_cpp) => llama_cpp.fim.as_ref(),
// ValidModel::OpenAI(openai) => openai.fim.as_ref(),
// ValidModel::Anthropic(_) => None,
// }
// }
// pub fn get_chat(&self) -> Option<&Chat> {
// match &self.config.transformer {
// ValidModel::LLaMACPP(llama_cpp) => llama_cpp.chat.as_ref(),
// ValidModel::OpenAI(openai) => openai.chat.as_ref(),
// ValidModel::Anthropic(anthropic) => Some(&anthropic.chat),
// }
// }
}
#[cfg(test)]

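Note on the config change above: making Completion::kwargs public lets the server serialize the user's raw completion options into a single opaque serde_json::Value and forward it downstream, replacing the deleted typed accessors (get_completion_max_context_length, get_fim, get_chat). A minimal sketch of that intent, where to_run_params is a hypothetical helper and not code from this commit:

use std::collections::HashMap;
use serde_json::Value;

// Hypothetical helper, not part of this commit: serialize the kwargs map once;
// each memory/transformer backend later deserializes only the fields it needs.
fn to_run_params(kwargs: &HashMap<String, Value>) -> anyhow::Result<Value> {
    Ok(serde_json::to_value(kwargs)?)
}
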
View File

@@ -3,6 +3,7 @@ use indexmap::IndexSet;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use ropey::Rope;
use serde_json::Value;
use std::collections::HashMap;
use tracing::instrument;
@@ -11,7 +12,7 @@ use crate::{
utils::tokens_to_estimated_characters,
};
use super::{MemoryBackend, Prompt, PromptForType};
use super::{MemoryBackend, MemoryRunParams, Prompt};
pub struct FileStore {
_crawl: bool,
@@ -104,33 +105,20 @@ impl FileStore {
pub fn build_code(
&self,
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
max_context_length: usize,
params: MemoryRunParams,
) -> anyhow::Result<String> {
let (mut rope, cursor_index) = self.get_rope_for_position(position, max_context_length)?;
let (mut rope, cursor_index) =
self.get_rope_for_position(position, params.max_context_length)?;
let is_chat_enabled = match prompt_for_type {
PromptForType::Completion => self
.config
.get_chat()
.map(|c| c.completion.is_some())
.unwrap_or(false),
PromptForType::Generate => self
.config
.get_chat()
.map(|c| c.generation.is_some())
.unwrap_or(false),
};
Ok(match (is_chat_enabled, self.config.get_fim()) {
Ok(match (params.chat.is_some(), params.fim) {
r @ (true, _) | r @ (false, Some(_)) if rope.len_chars() != cursor_index => {
let max_length = tokens_to_estimated_characters(max_context_length);
let max_length = tokens_to_estimated_characters(params.max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
if is_chat_enabled {
if r.0 {
rope.insert(cursor_index, "<CURSOR>");
let rope_slice = rope
.get_slice(start..end + "<CURSOR>".chars().count())
@@ -154,8 +142,8 @@ impl FileStore {
}
}
_ => {
let start =
cursor_index.saturating_sub(tokens_to_estimated_characters(max_context_length));
let start = cursor_index
.saturating_sub(tokens_to_estimated_characters(params.max_context_length));
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
@@ -190,10 +178,10 @@ impl MemoryBackend for FileStore {
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
max_context_length: usize,
prompt_for_type: PromptForType,
params: Value,
) -> anyhow::Result<Prompt> {
let code = self.build_code(position, prompt_for_type, max_context_length)?;
let params: MemoryRunParams = serde_json::from_value(params)?;
let code = self.build_code(position, params)?;
Ok(Prompt::new("".to_string(), code))
}

View File

@@ -2,12 +2,26 @@ use lsp_types::{
DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
TextDocumentPositionParams,
};
use serde::Deserialize;
use serde_json::Value;
use crate::config::{Config, ValidMemoryBackend};
use crate::config::{ChatMessage, Config, ValidMemoryBackend, FIM};
pub mod file_store;
mod postgresml;
const fn max_context_length_default() -> usize {
1024
}
#[derive(Clone, Deserialize)]
struct MemoryRunParams {
pub fim: Option<FIM>,
pub chat: Option<Vec<ChatMessage>>,
#[serde(default = "max_context_length_default")]
pub max_context_length: usize,
}
#[derive(Debug)]
pub struct Prompt {
pub context: String,
@@ -20,12 +34,6 @@ impl Prompt {
}
}
#[derive(Debug)]
pub enum PromptForType {
Completion,
Generate,
}
#[async_trait::async_trait]
pub trait MemoryBackend {
async fn init(&self) -> anyhow::Result<()> {
@@ -40,8 +48,7 @@ pub trait MemoryBackend {
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
max_context_length: usize,
prompt_for_type: PromptForType,
params: Value,
) -> anyhow::Result<Prompt>;
async fn get_filter_text(
&self,

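With PromptForType removed, the memory backends now pull everything they need out of the forwarded kwargs through the MemoryRunParams struct above. A hypothetical check (not code from this commit) of how that deserialization should behave under serde's usual rules: missing Option fields become None, unrelated keys are ignored, and a missing max_context_length falls back to the 1024 default.

fn memory_run_params_defaults() -> anyhow::Result<()> {
    // Transformer-only keys in the kwargs (e.g. max_tokens) are simply ignored here.
    let params: MemoryRunParams = serde_json::from_value(serde_json::json!({
        "max_tokens": 64
    }))?;
    assert_eq!(params.max_context_length, 1024); // max_context_length_default()
    assert!(params.fim.is_none() && params.chat.is_none());
    Ok(())
}
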
View File

@@ -6,7 +6,7 @@ use std::{
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use pgml::{Collection, Pipeline};
use serde_json::json;
use serde_json::{json, Value};
use tokio::time;
use tracing::instrument;
@@ -15,7 +15,7 @@ use crate::{
utils::tokens_to_estimated_characters,
};
use super::{file_store::FileStore, MemoryBackend, Prompt, PromptForType};
use super::{file_store::FileStore, MemoryBackend, MemoryRunParams, Prompt};
pub struct PostgresML {
configuration: Config,
@@ -129,9 +129,9 @@ impl MemoryBackend for PostgresML {
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
max_context_length: usize,
prompt_for_type: PromptForType,
params: Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = serde_json::from_value(params)?;
let query = self
.file_store
.get_characters_around_position(position, 512)?;
@@ -162,8 +162,10 @@ impl MemoryBackend for PostgresML {
})
.collect::<anyhow::Result<Vec<String>>>()?
.join("\n\n");
let code = self.file_store.build_code(position, prompt_for_type, 512)?;
let max_characters = tokens_to_estimated_characters(max_context_length);
let mut file_store_params = params.clone();
file_store_params.max_context_length = 512;
let code = self.file_store.build_code(position, file_store_params)?;
let max_characters = tokens_to_estimated_characters(params.max_context_length);
let context: String = context
.chars()
.take(max_characters - code.chars().count())

View File

@@ -4,29 +4,27 @@ use lsp_types::{
DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
TextDocumentPositionParams,
};
use serde_json::Value;
use tracing::error;
use crate::memory_backends::{MemoryBackend, Prompt, PromptForType};
use crate::memory_backends::{MemoryBackend, Prompt};
#[derive(Debug)]
pub struct PromptRequest {
position: TextDocumentPositionParams,
max_context_length: usize,
prompt_for_type: PromptForType,
params: Value,
tx: tokio::sync::oneshot::Sender<Prompt>,
}
impl PromptRequest {
pub fn new(
position: TextDocumentPositionParams,
max_context_length: usize,
prompt_for_type: PromptForType,
params: Value,
tx: tokio::sync::oneshot::Sender<Prompt>,
) -> Self {
Self {
position,
max_context_length,
prompt_for_type,
params,
tx,
}
}
@@ -69,11 +67,7 @@ async fn do_task(
}
WorkerRequest::Prompt(params) => {
let prompt = memory_backend
.build_prompt(
&params.position,
params.max_context_length,
params.prompt_for_type,
)
.build_prompt(&params.position, params.params)
.await?;
params
.tx

View File

@@ -13,7 +13,7 @@ use crate::{
utils::format_chat_messages,
};
use super::{RunParams, TransformerBackend};
use super::TransformerBackend;
const fn max_tokens_default() -> usize {
64
@@ -27,10 +27,6 @@ const fn temperature_default() -> f32 {
0.1
}
const fn max_context_length_default() -> usize {
1024
}
#[derive(Debug, Deserialize)]
pub struct AnthropicRunParams {
chat: Vec<ChatMessage>,
@@ -40,8 +36,6 @@ pub struct AnthropicRunParams {
pub top_p: f32,
#[serde(default = "temperature_default")]
pub temperature: f32,
#[serde(default = "max_context_length_default")]
max_context_length: usize,
}
pub struct Anthropic {
@@ -133,9 +127,10 @@ impl TransformerBackend for Anthropic {
async fn do_completion(
&self,
prompt: &Prompt,
params: RunParams,
params: Value,
) -> anyhow::Result<DoCompletionResponse> {
let params: AnthropicRunParams = params.try_into()?;
// let params: AnthropicRunParams = params.try_into()?;
let params: AnthropicRunParams = serde_json::from_value(params)?;
let insert_text = self.do_get_chat(prompt, params).await?;
Ok(DoCompletionResponse { insert_text })
}
@@ -144,9 +139,9 @@ impl TransformerBackend for Anthropic {
async fn do_generate(
&self,
prompt: &Prompt,
params: RunParams,
params: Value,
) -> anyhow::Result<DoGenerationResponse> {
let params: AnthropicRunParams = params.try_into()?;
let params: AnthropicRunParams = serde_json::from_value(params)?;
let generated_text = self.do_get_chat(prompt, params).await?;
Ok(DoGenerationResponse { generated_text })
}
@@ -155,7 +150,7 @@ impl TransformerBackend for Anthropic {
async fn do_generate_stream(
&self,
request: &GenerationStreamRequest,
params: RunParams,
params: Value,
) -> anyhow::Result<DoGenerationStreamResponse> {
unimplemented!()
}

View File

@@ -1,6 +1,7 @@
use anyhow::Context;
use hf_hub::api::sync::ApiBuilder;
use serde::Deserialize;
use serde_json::Value;
use tracing::instrument;
use crate::{
@@ -17,7 +18,7 @@ use crate::{
mod model;
use model::Model;
use super::{RunParams, TransformerBackend};
use super::TransformerBackend;
const fn max_new_tokens_default() -> usize {
32
@@ -86,9 +87,9 @@ impl TransformerBackend for LLaMACPP {
async fn do_completion(
&self,
prompt: &Prompt,
params: RunParams,
params: Value,
) -> anyhow::Result<DoCompletionResponse> {
let params: LLaMACPPRunParams = params.try_into()?;
let params: LLaMACPPRunParams = serde_json::from_value(params)?;
let prompt = self.get_prompt_string(prompt, &params)?;
self.model
.complete(&prompt, params)
@@ -99,9 +100,9 @@ impl TransformerBackend for LLaMACPP {
async fn do_generate(
&self,
prompt: &Prompt,
params: RunParams,
params: Value,
) -> anyhow::Result<DoGenerationResponse> {
let params: LLaMACPPRunParams = params.try_into()?;
let params: LLaMACPPRunParams = serde_json::from_value(params)?;
let prompt = self.get_prompt_string(prompt, &params)?;
self.model
.complete(&prompt, params)
@@ -112,7 +113,7 @@ impl TransformerBackend for LLaMACPP {
async fn do_generate_stream(
&self,
_request: &GenerationStreamRequest,
params: RunParams,
params: Value,
) -> anyhow::Result<DoGenerationStreamResponse> {
unimplemented!()
}

View File

@@ -1,3 +1,5 @@
use serde_json::Value;
use crate::{
config::{self, ValidModel},
memory_backends::Prompt,
@@ -13,56 +15,47 @@ mod anthropic;
mod llama_cpp;
mod openai;
#[derive(Debug)]
pub enum RunParams {
LLaMACPP(llama_cpp::LLaMACPPRunParams),
Anthropic(anthropic::AnthropicRunParams),
OpenAI(openai::OpenAIRunParams),
}
// impl RunParams {
// pub fn from_completion(completion: &Completion) -> Self {
// todo!()
// }
// }
impl RunParams {
pub fn from_completion(completion: &Completion) -> Self {
todo!()
}
}
// macro_rules! impl_runparams_try_into {
// ( $f:ident, $t:ident ) => {
// impl TryInto<$f> for RunParams {
// type Error = anyhow::Error;
macro_rules! impl_runparams_try_into {
( $f:ident, $t:ident ) => {
impl TryInto<$f> for RunParams {
type Error = anyhow::Error;
// fn try_into(self) -> Result<$f, Self::Error> {
// match self {
// Self::$t(a) => Ok(a),
// _ => anyhow::bail!("Cannot convert RunParams into {}", stringify!($f)),
// }
// }
// }
// };
// }
fn try_into(self) -> Result<$f, Self::Error> {
match self {
Self::$t(a) => Ok(a),
_ => anyhow::bail!("Cannot convert RunParams into {}", stringify!($f)),
}
}
}
};
}
impl_runparams_try_into!(AnthropicRunParams, Anthropic);
impl_runparams_try_into!(LLaMACPPRunParams, LLaMACPP);
impl_runparams_try_into!(OpenAIRunParams, OpenAI);
// impl_runparams_try_into!(AnthropicRunParams, Anthropic);
// impl_runparams_try_into!(LLaMACPPRunParams, LLaMACPP);
// impl_runparams_try_into!(OpenAIRunParams, OpenAI);
#[async_trait::async_trait]
pub trait TransformerBackend {
type Test = LLaMACPPRunParams;
async fn do_completion(
&self,
prompt: &Prompt,
params: RunParams,
params: Value,
) -> anyhow::Result<DoCompletionResponse>;
async fn do_generate(
&self,
prompt: &Prompt,
params: RunParams,
params: Value,
) -> anyhow::Result<DoGenerationResponse>;
async fn do_generate_stream(
&self,
request: &GenerationStreamRequest,
params: RunParams,
params: Value,
) -> anyhow::Result<DoGenerationStreamResponse>;
}

View File

@@ -13,7 +13,7 @@ use crate::{
utils::{format_chat_messages, format_context_code},
};
use super::{RunParams, TransformerBackend};
use super::TransformerBackend;
const fn max_tokens_default() -> usize {
64
@@ -35,10 +35,6 @@ const fn temperature_default() -> f32 {
0.1
}
const fn max_context_length_default() -> usize {
1024
}
#[derive(Debug, Deserialize)]
pub struct OpenAIRunParams {
pub fim: Option<FIM>,
@@ -53,8 +49,6 @@ pub struct OpenAIRunParams {
pub frequency_penalty: f32,
#[serde(default = "temperature_default")]
pub temperature: f32,
#[serde(default = "max_context_length_default")]
max_context_length: usize,
}
pub struct OpenAI {
@@ -202,9 +196,10 @@ impl TransformerBackend for OpenAI {
async fn do_completion(
&self,
prompt: &Prompt,
params: RunParams,
params: Value,
) -> anyhow::Result<DoCompletionResponse> {
let params: OpenAIRunParams = params.try_into()?;
// let params: OpenAIRunParams = params.try_into()?;
let params: OpenAIRunParams = serde_json::from_value(params)?;
let insert_text = self.do_chat_completion(prompt, params).await?;
Ok(DoCompletionResponse { insert_text })
}
@@ -214,9 +209,9 @@ impl TransformerBackend for OpenAI {
&self,
prompt: &Prompt,
params: RunParams,
params: Value,
) -> anyhow::Result<DoGenerationResponse> {
let params: OpenAIRunParams = params.try_into()?;
let params: OpenAIRunParams = serde_json::from_value(params)?;
let generated_text = self.do_chat_completion(prompt, params).await?;
Ok(DoGenerationResponse { generated_text })
}
@@ -225,7 +220,7 @@ impl TransformerBackend for OpenAI {
async fn do_generate_stream(
&self,
request: &GenerationStreamRequest,
params: RunParams,
params: Value,
) -> anyhow::Result<DoGenerationStreamResponse> {
unimplemented!()
}

View File

@@ -14,7 +14,6 @@ use tracing::{error, instrument};
use crate::config::Config;
use crate::custom_requests::generation::{GenerateResult, GenerationParams};
use crate::custom_requests::generation_stream::GenerationStreamParams;
use crate::memory_backends::PromptForType;
use crate::memory_worker::{self, FilterRequest, PromptRequest};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;
@@ -219,15 +218,13 @@ async fn do_completion(
// TODO: Fix this
// we need to be subtracting the completion / generation tokens from max_context_length
// not sure if we should be doing that for the chat maybe leave a note here for that?
// let max_context_length = config.get_completion_max_context_length()?;
let params:
let params = serde_json::to_value(config.config.completion.kwargs.clone()).unwrap();
let (tx, rx) = oneshot::channel();
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
request.params.text_document_position.clone(),
max_context_length,
PromptForType::Completion,
params.clone(),
tx,
)))?;
let prompt = rx.await?;
@@ -238,7 +235,7 @@ async fn do_completion(
))?;
let filter_text = rx.await?;
let response = transformer_backend.do_completion(&prompt).await?;
let response = transformer_backend.do_completion(&prompt, params).await?;
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
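
A closing note on the TODO in the last hunk: one hypothetical way to reserve the generation budget before the prompt is built. The field names (max_new_tokens, max_context_length) are illustrative assumptions, not taken from this commit:

// Hypothetical sketch of the TODO above: keep prompt + generated tokens inside
// the context window by subtracting the generation budget up front.
let max_new_tokens = params.get("max_new_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
let max_context_length = params.get("max_context_length").and_then(|v| v.as_u64()).unwrap_or(1024) as usize;
let prompt_budget = max_context_length.saturating_sub(max_new_tokens);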