Reorganized

Silas Marvin
2024-01-15 14:59:32 -08:00
parent 85dfdcd90e
commit 79524b6c06
8 changed files with 348 additions and 140 deletions

@@ -12,6 +12,14 @@
   "engines": {
     "vscode": "^1.75.0"
   },
+  "contributes": {
+    "commands": [
+      {
+        "command": "lsp-ai.generate",
+        "title": "LSP AI Generate"
+      }
+    ]
+  },
   "devDependencies": {
     "@types/node": "^20.11.0",
     "typescript": "^5.3.3"

@@ -1,5 +1,4 @@
-import { workspace, ExtensionContext } from 'vscode';
+import * as vscode from 'vscode';
 import {
     LanguageClient,
     LanguageClientOptions,
@@ -9,9 +8,7 @@ import {
 let client: LanguageClient;
-export function activate(_context: ExtensionContext) {
-    console.log("\n\nIN THE ACTIVATE FUNCTION\n\n");
+export function activate(context: vscode.ExtensionContext) {
     // Configure the server options
     let serverOptions: ServerOptions = {
         command: "lsp-ai",
@@ -20,11 +17,7 @@ export function activate(_context: ExtensionContext) {
     // Options to control the language client
     let clientOptions: LanguageClientOptions = {
-        documentSelector: [{ scheme: 'file', language: 'python' }],
-        synchronize: {
-            // Notify the server about file changes to '.clientrc files contained in the workspace
-            fileEvents: workspace.createFileSystemWatcher('**/.clientrc')
-        }
+        documentSelector: [{ pattern: "**" }],
     };

     // Create the language client and start the client
@@ -35,10 +28,34 @@ export function activate(_context: ExtensionContext) {
         clientOptions
     );
-    console.log("\n\nSTARTING THE CLIENT\n\n");
     // Start the client. This will also launch the server
     client.start();
-    client.onRequest("textDocument/completion", (params) => {
-        console.log("HERE WE GO");
-        console.log(params);
-    });
+    // Register functions
+    const command = 'lsp-ai.generate';
+    const commandHandler = () => {
+        const editor = vscode.window.activeTextEditor;
+        console.log("SENDING REQUEST FOR GENERATE");
+        console.log(editor);
+        let params = {
+            textDocument: {
+                uri: editor.document.uri.toString(),
+            },
+            position: editor.selection.active
+        };
+        console.log(params);
+        client.sendRequest("textDocument/generate", params).then(result => {
+            console.log(result);
+        }).catch(error => {
+            console.error(error);
+        });
+    };
+    context.subscriptions.push(vscode.commands.registerCommand(command, commandHandler));
 }

 export function deactivate(): Thenable<void> | undefined {

@@ -0,0 +1,23 @@
use lsp_types::TextDocumentPositionParams;
use serde::{Deserialize, Serialize};
pub enum Generate {}
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateParams {
// This field was "mixed-in" from TextDocumentPositionParams
#[serde(flatten)]
pub text_document_position: TextDocumentPositionParams,
}
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
pub struct GenerateResult {
pub generated_text: String,
}
impl lsp_types::request::Request for Generate {
type Params = GenerateParams;
type Result = GenerateResult;
const METHOD: &'static str = "textDocument/generate";
}
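
Because GenerateParams flattens TextDocumentPositionParams, a textDocument/generate request has the same wire shape as any position-based LSP request. A serialization sketch, assuming only the lsp_types, serde_json, and anyhow crates already used in this commit (the demo function is illustrative, not part of the change):

use lsp_types::{Position, TextDocumentIdentifier, TextDocumentPositionParams, Url};

// Illustrative only: serialize GenerateParams (defined above) to see the JSON
// the VS Code client must send.
fn demo() -> anyhow::Result<()> {
    let params = GenerateParams {
        text_document_position: TextDocumentPositionParams {
            text_document: TextDocumentIdentifier {
                uri: Url::parse("file:///tmp/example.py")?,
            },
            position: Position::new(4, 10),
        },
    };
    // Prints: {"textDocument":{"uri":"file:///tmp/example.py"},"position":{"line":4,"character":10}}
    println!("{}", serde_json::to_string(&params)?);
    Ok(())
}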

@@ -0,0 +1 @@
pub mod generate;

@@ -1,21 +1,21 @@
-use anyhow::Context;
-use anyhow::Result;
-use core::panic;
-use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId, Response};
+use anyhow::{Context, Result};
+use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId};
 use lsp_types::{
-    request::Completion, CompletionItem, CompletionItemKind, CompletionList, CompletionOptions,
-    CompletionParams, CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
-    Position, Range, RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit,
+    request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
+    RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
 };
-use parking_lot::Mutex;
+use once_cell::sync::Lazy;
+use parking_lot::RwLock;
+use pyo3::prelude::*;
 use ropey::Rope;
 use serde::Deserialize;
-use std::collections::HashMap;
-use std::sync::Arc;
-use std::thread;
-use once_cell::sync::Lazy;
-use pyo3::prelude::*;
+use std::{collections::HashMap, sync::Arc, thread};

+mod custom_requests;
+mod worker;
+use custom_requests::generate::Generate;
+use worker::{CompletionRequest, GenerateRequest, WorkerRequest};

 pub static PY_MODULE: Lazy<Result<Py<PyAny>>> = Lazy::new(|| {
     pyo3::Python::with_gil(|py| -> Result<Py<PyAny>> {
@@ -67,23 +67,11 @@ fn main() -> Result<()> {
 #[derive(Deserialize)]
 struct Params {}

-struct CompletionRequest {
-    id: RequestId,
-    params: CompletionParams,
-    rope: Rope,
-}
-
-impl CompletionRequest {
-    fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self {
-        Self { id, params, rope }
-    }
-}
-
 // This main loop is tricky
 // We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get
 // Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request
 fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
-    let params: Params = serde_json::from_value(params)?;
+    let _params: Params = serde_json::from_value(params)?;

     // Set the model
     Python::with_gil(|py| -> Result<()> {
@@ -100,98 +88,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
     let mut file_map: HashMap<String, Rope> = HashMap::new();

     // How we communicate between the worker and receiver threads
-    let last_completion_request = Arc::new(Mutex::new(None));
+    let last_worker_request = Arc::new(RwLock::new(None));

     // Thread local variables
-    let thread_last_completion_request = last_completion_request.clone();
+    let thread_last_worker_request = last_worker_request.clone();
     let thread_connection = connection.clone();

     thread::spawn(move || {
-        loop {
-            // I think we need this drop, not 100% sure though
-            let mut completion_request = thread_last_completion_request.lock();
-            let params = std::mem::take(&mut *completion_request);
-            drop(completion_request);
-            if let Some(CompletionRequest {
-                id,
-                params,
-                mut rope,
-            }) = params
-            {
-                let filter_text = rope
-                    .get_line(params.text_document_position.position.line as usize)
-                    .expect("Error getting line with ropey")
-                    .to_string();
-                // Convert rope to correct prompt for llm
-                let cursor_index = rope
-                    .line_to_char(params.text_document_position.position.line as usize)
-                    + params.text_document_position.position.character as usize;
-                // We will want to have some kind of infill support we add
-                // rope.insert(cursor_index, "<fim_hole>");
-                // rope.insert(0, "<fim_start>");
-                // rope.insert(rope.len_chars(), "<fim_end>");
-                // let prompt = rope.to_string();
-                let prompt = rope
-                    .get_slice((0..cursor_index))
-                    .expect("Error getting rope slice")
-                    .to_string();
-                eprintln!("\n\n****{prompt}****\n\n");
-                let insert_text = Python::with_gil(|py| -> Result<String> {
-                    let transform: Py<PyAny> = PY_MODULE
-                        .as_ref()
-                        .map_err(anyhow::Error::msg)?
-                        .getattr(py, "transform")?;
-                    let out: String = transform.call1(py, (prompt,))?.extract(py)?;
-                    Ok(out)
-                })
-                .expect("Error during transform");
-                eprintln!("\n{insert_text}\n");
-                // Create and return the completion
-                let completion_text_edit = TextEdit::new(
-                    Range::new(
-                        Position::new(
-                            params.text_document_position.position.line,
-                            params.text_document_position.position.character,
-                        ),
-                        Position::new(
-                            params.text_document_position.position.line,
-                            params.text_document_position.position.character,
-                        ),
-                    ),
-                    insert_text.clone(),
-                );
-                let item = CompletionItem {
-                    label: format!("ai - {insert_text}"),
-                    filter_text: Some(filter_text),
-                    text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
-                    kind: Some(CompletionItemKind::TEXT),
-                    ..Default::default()
-                };
-                let completion_list = CompletionList {
-                    is_incomplete: false,
-                    items: vec![item],
-                };
-                let result = Some(CompletionResponse::List(completion_list));
-                let result = serde_json::to_value(&result).unwrap();
-                let resp = Response {
-                    id,
-                    result: Some(result),
-                    error: None,
-                };
-                thread_connection
-                    .sender
-                    .send(Message::Response(resp))
-                    .expect("Error sending response");
-            }
-            thread::sleep(std::time::Duration::from_millis(5));
-        }
+        worker::run(thread_last_worker_request, thread_connection);
     });

     for msg in &connection.receiver {
@@ -200,21 +103,44 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
                 if connection.handle_shutdown(&req)? {
                     return Ok(());
                 }
-                match cast::<Completion>(req) {
-                    Ok((id, params)) => {
-                        // Get rope
-                        let rope = file_map
-                            .get(params.text_document_position.text_document.uri.as_str())
-                            .context("Error file not found")?
-                            .clone();
-                        // Update the last CompletionRequest
-                        let mut lcr = last_completion_request.lock();
-                        *lcr = Some(CompletionRequest::new(id, params, rope));
-                        continue;
-                    }
-                    Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
-                    Err(ExtractError::MethodMismatch(req)) => req,
-                };
+                eprintln!(
+                    "NEW REQUEST: \n{}\n",
+                    serde_json::to_string_pretty(&req).unwrap()
+                );
+                match req.method.as_str() {
+                    "textDocument/completion" => match cast::<Completion>(req) {
+                        Ok((id, params)) => {
+                            // Get rope
+                            let rope = file_map
+                                .get(params.text_document_position.text_document.uri.as_str())
+                                .context("Error file not found")?
+                                .clone();
+                            // Update the last CompletionRequest
+                            let mut lcr = last_worker_request.write();
+                            let completion_request = CompletionRequest::new(id, params, rope);
+                            *lcr = Some(WorkerRequest::Completion(completion_request));
+                        }
+                        Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
+                        Err(ExtractError::MethodMismatch(_req)) => (),
+                    },
+                    "textDocument/generate" => match cast::<Generate>(req) {
+                        Ok((id, params)) => {
+                            // Get rope
+                            let rope = file_map
+                                .get(params.text_document_position.text_document.uri.as_str())
+                                .context("Error file not found")?
+                                .clone();
+                            // Update the last CompletionRequest
+                            let mut lcr = last_worker_request.write();
+                            let completion_request = GenerateRequest::new(id, params, rope);
+                            *lcr = Some(WorkerRequest::Generate(completion_request));
+                        }
+                        Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
+                        Err(ExtractError::MethodMismatch(_req)) => (),
+                    }
+                    _ => eprintln!("lsp-ai currently only supports textDocument/completion and textDocument/generate")
+                }
             }
             Message::Notification(not) => {
                 if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
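
A note on the dispatch above: cast is a helper that predates this commit and never appears in the diff. In lsp-server's example servers it is the small generic below, which is presumably close to what this repository uses (a sketch, not a confirmed copy of the actual code):

use lsp_server::{ExtractError, Request, RequestId};

// Generic request extractor: Request::extract checks the method string and
// deserializes the params in one step.
fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
where
    R: lsp_types::request::Request,
    R::Params: serde::de::DeserializeOwned,
{
    req.extract(R::METHOD)
}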

src/worker/completion.rs (new file, 61 lines)

@@ -0,0 +1,61 @@
use lsp_server::ResponseError;
use pyo3::prelude::*;
use super::CompletionRequest;
use crate::PY_MODULE;
pub struct DoCompletionResponse {
pub insert_text: String,
pub filter_text: String,
}
pub fn do_completion(request: &CompletionRequest) -> Result<DoCompletionResponse, ResponseError> {
let filter_text = request
.rope
.get_line(request.params.text_document_position.position.line as usize)
.ok_or(ResponseError {
code: -32603, // Maybe we want a different error code here?
message: "Error getting line in requested document".to_string(),
data: None,
})?
.to_string();
// Convert rope to correct prompt for llm
let cursor_index = request
.rope
.line_to_char(request.params.text_document_position.position.line as usize)
+ request.params.text_document_position.position.character as usize;
// We will want to have some kind of infill support we add
// rope.insert(cursor_index, "<fim_hole>");
// rope.insert(0, "<fim_start>");
// rope.insert(rope.len_chars(), "<fim_end>");
// let prompt = rope.to_string();
let prompt = request
.rope
.get_slice(0..cursor_index)
.expect("Error getting rope slice")
.to_string();
eprintln!("\n\n****{prompt}****\n\n");
Python::with_gil(|py| -> anyhow::Result<String> {
let transform: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "transform")?;
let out: String = transform.call1(py, (prompt,))?.extract(py)?;
Ok(out)
})
.map(|insert_text| DoCompletionResponse {
insert_text,
filter_text,
})
.map_err(|e| ResponseError {
code: -32603,
message: e.to_string(),
data: None,
})
}
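
The -32603 above is the JSON-RPC 2.0 reserved code for "Internal error", which is why it doubles as the catch-all for both failures here and in generate.rs below. A named constant would make that intent explicit (hypothetical; the commit uses the literal directly):

// JSON-RPC 2.0 reserved error code for "Internal error" (hypothetical constant).
const INTERNAL_ERROR: i32 = -32603;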

src/worker/generate.rs (new file, 47 lines)

@@ -0,0 +1,47 @@
use lsp_server::ResponseError;
use pyo3::prelude::*;
use super::GenerateRequest;
use crate::PY_MODULE;
pub struct DoGenerateResponse {
pub generated_text: String,
}
pub fn do_generate(request: &GenerateRequest) -> Result<DoGenerateResponse, ResponseError> {
// Convert rope to correct prompt for llm
let cursor_index = request
.rope
.line_to_char(request.params.text_document_position.position.line as usize)
+ request.params.text_document_position.position.character as usize;
// We will want to have some kind of infill support we add
// rope.insert(cursor_index, "<fim_hole>");
// rope.insert(0, "<fim_start>");
// rope.insert(rope.len_chars(), "<fim_end>");
// let prompt = rope.to_string();
let prompt = request
.rope
.get_slice(0..cursor_index)
.expect("Error getting rope slice")
.to_string();
eprintln!("\n\n****{prompt}****\n\n");
Python::with_gil(|py| -> anyhow::Result<String> {
let transform: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "transform")?;
let out: String = transform.call1(py, (prompt,))?.extract(py)?;
Ok(out)
})
.map(|generated_text| DoGenerateResponse { generated_text })
.map_err(|e| ResponseError {
code: -32603,
message: e.to_string(),
data: None,
})
}
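
do_generate repeats do_completion's cursor-to-prompt arithmetic line for line. If the prompt construction is meant to stay identical in both paths, a shared helper along these lines would keep them in sync (hypothetical; prompt_before_cursor is not in the commit):

use lsp_types::Position;
use ropey::Rope;

// Everything from the start of the document up to the cursor becomes the
// prompt; returns None if the position is out of bounds for this rope.
fn prompt_before_cursor(rope: &Rope, position: &Position) -> Option<String> {
    let cursor_index =
        rope.try_line_to_char(position.line as usize).ok()? + position.character as usize;
    Some(rope.get_slice(0..cursor_index)?.to_string())
}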

src/worker/mod.rs (new file, 125 lines)

@@ -0,0 +1,125 @@
use lsp_server::{Connection, Message, RequestId, Response};
use lsp_types::{
CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
Position, Range, TextEdit,
};
use parking_lot::RwLock;
use ropey::Rope;
use std::{sync::Arc, thread};
mod completion;
mod generate;
use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use completion::do_completion;
use generate::do_generate;
#[derive(Clone)]
pub struct CompletionRequest {
id: RequestId,
params: CompletionParams,
rope: Rope,
}
impl CompletionRequest {
pub fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self {
Self { id, params, rope }
}
}
#[derive(Clone)]
pub struct GenerateRequest {
id: RequestId,
params: GenerateParams,
rope: Rope,
}
impl GenerateRequest {
pub fn new(id: RequestId, params: GenerateParams, rope: Rope) -> Self {
Self { id, params, rope }
}
}
#[derive(Clone)]
pub enum WorkerRequest {
Completion(CompletionRequest),
Generate(GenerateRequest),
}
pub fn run(last_worker_request: Arc<RwLock<Option<WorkerRequest>>>, connection: Arc<Connection>) {
loop {
let option_worker_request: Option<WorkerRequest> = {
let completion_request = last_worker_request.read();
(*completion_request).clone()
};
if let Some(request) = option_worker_request {
let response = match request {
WorkerRequest::Completion(request) => match do_completion(&request) {
Ok(response) => {
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
),
response.insert_text.clone(),
);
let item = CompletionItem {
label: format!("ai - {}", response.insert_text),
filter_text: Some(response.filter_text),
text_edit: Some(lsp_types::CompletionTextEdit::Edit(
completion_text_edit,
)),
kind: Some(CompletionItemKind::TEXT),
..Default::default()
};
let completion_list = CompletionList {
is_incomplete: false,
items: vec![item],
};
let result = Some(CompletionResponse::List(completion_list));
let result = serde_json::to_value(&result).unwrap();
Response {
id: request.id,
result: Some(result),
error: None,
}
}
Err(e) => Response {
id: request.id,
result: None,
error: Some(e),
},
},
WorkerRequest::Generate(request) => match do_generate(&request) {
Ok(result) => {
let result = GenerateResult {
generated_text: result.generated_text,
};
let result = serde_json::to_value(&result).unwrap();
Response {
id: request.id,
result: Some(result),
error: None,
}
}
Err(e) => Response {
id: request.id,
result: None,
error: Some(e),
},
},
};
connection
.sender
.send(Message::Response(response))
.expect("Error sending response");
}
thread::sleep(std::time::Duration::from_millis(5));
}
}
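
One behavioral difference from the old inline loop is worth flagging: the removed main.rs code took the pending request out of the slot with std::mem::take, while run only read()s and clones it, so a handled WorkerRequest stays Some(..) and looks ready to handle again on the next 5 ms tick. If that re-sending is unintended, restoring the take semantics is a small change (a sketch, assuming nothing else relies on the value staying in the lock):

use parking_lot::RwLock;
use std::sync::Arc;

// Sketch: take the pending request instead of cloning it, mirroring the
// std::mem::take the old main.rs worker loop used, so each request is
// answered exactly once.
fn take_request(slot: &Arc<RwLock<Option<WorkerRequest>>>) -> Option<WorkerRequest> {
    std::mem::take(&mut *slot.write())
}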