diff --git a/editors/vscode/package.json b/editors/vscode/package.json index 8c53371..0121842 100644 --- a/editors/vscode/package.json +++ b/editors/vscode/package.json @@ -12,6 +12,14 @@ "engines": { "vscode": "^1.75.0" }, + "contributes": { + "commands": [ + { + "command": "lsp-ai.generate", + "title": "LSP AI Generate" + } + ] + }, "devDependencies": { "@types/node": "^20.11.0", "typescript": "^5.3.3" diff --git a/editors/vscode/src/index.ts b/editors/vscode/src/index.ts index 06b7a68..2790037 100644 --- a/editors/vscode/src/index.ts +++ b/editors/vscode/src/index.ts @@ -1,5 +1,4 @@ -import { workspace, ExtensionContext } from 'vscode'; - +import * as vscode from 'vscode'; import { LanguageClient, LanguageClientOptions, @@ -9,9 +8,7 @@ import { let client: LanguageClient; -export function activate(_context: ExtensionContext) { - console.log("\n\nIN THE ACTIVATE FUNCTION\n\n"); - +export function activate(context: vscode.ExtensionContext) { // Configure the server options let serverOptions: ServerOptions = { command: "lsp-ai", @@ -20,11 +17,7 @@ export function activate(_context: ExtensionContext) { // Options to control the language client let clientOptions: LanguageClientOptions = { - documentSelector: [{ scheme: 'file', language: 'python' }], - synchronize: { - // Notify the server about file changes to '.clientrc files contained in the workspace - fileEvents: workspace.createFileSystemWatcher('**/.clientrc') - } + documentSelector: [{ pattern: "**" }], }; // Create the language client and start the client @@ -35,10 +28,34 @@ export function activate(_context: ExtensionContext) { clientOptions ); - console.log("\n\nSTARTING THE CLIENT\n\n"); - // Start the client. This will also launch the server client.start(); + + client.onRequest("textDocument/completion", (params) => { + console.log("HERE WE GO"); + console.log(params); + }); + + // Register functions + const command = 'lsp-ai.generate'; + const commandHandler = () => { + const editor = vscode.window.activeTextEditor; + console.log("SENDING REQUEST FOR GENERATE"); + console.log(editor); + let params = { + textDocument: { + uri: editor.document.uri.toString(), + }, + position: editor.selection.active + }; + console.log(params); + client.sendRequest("textDocument/generate", params).then(result => { + console.log(result); + }).catch(error => { + console.error(error); + }); + }; + context.subscriptions.push(vscode.commands.registerCommand(command, commandHandler)); } export function deactivate(): Thenable | undefined { diff --git a/src/custom_requests/generate.rs b/src/custom_requests/generate.rs new file mode 100644 index 0000000..b163965 --- /dev/null +++ b/src/custom_requests/generate.rs @@ -0,0 +1,23 @@ +use lsp_types::TextDocumentPositionParams; +use serde::{Deserialize, Serialize}; + +pub enum Generate {} + +#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerateParams { + // This field was "mixed-in" from TextDocumentPositionParams + #[serde(flatten)] + pub text_document_position: TextDocumentPositionParams, +} + +#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] +pub struct GenerateResult { + pub generated_text: String, +} + +impl lsp_types::request::Request for Generate { + type Params = GenerateParams; + type Result = GenerateResult; + const METHOD: &'static str = "textDocument/generate"; +} diff --git a/src/custom_requests/mod.rs b/src/custom_requests/mod.rs new file mode 100644 index 0000000..57e6a9d --- /dev/null +++ b/src/custom_requests/mod.rs @@ -0,0 +1 @@ +pub mod generate; diff --git a/src/main.rs b/src/main.rs index 41d699c..b186e40 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,21 @@ -use anyhow::Context; -use anyhow::Result; -use core::panic; -use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId, Response}; +use anyhow::{Context, Result}; +use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId}; use lsp_types::{ - request::Completion, CompletionItem, CompletionItemKind, CompletionList, CompletionOptions, - CompletionParams, CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams, - Position, Range, RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit, + request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams, + RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, }; -use parking_lot::Mutex; +use once_cell::sync::Lazy; +use parking_lot::RwLock; +use pyo3::prelude::*; use ropey::Rope; use serde::Deserialize; -use std::collections::HashMap; -use std::sync::Arc; -use std::thread; +use std::{collections::HashMap, sync::Arc, thread}; -use once_cell::sync::Lazy; -use pyo3::prelude::*; +mod custom_requests; +mod worker; + +use custom_requests::generate::Generate; +use worker::{CompletionRequest, GenerateRequest, WorkerRequest}; pub static PY_MODULE: Lazy>> = Lazy::new(|| { pyo3::Python::with_gil(|py| -> Result> { @@ -67,23 +67,11 @@ fn main() -> Result<()> { #[derive(Deserialize)] struct Params {} -struct CompletionRequest { - id: RequestId, - params: CompletionParams, - rope: Rope, -} - -impl CompletionRequest { - fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self { - Self { id, params, rope } - } -} - // This main loop is tricky // We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get // Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { - let params: Params = serde_json::from_value(params)?; + let _params: Params = serde_json::from_value(params)?; // Set the model Python::with_gil(|py| -> Result<()> { @@ -100,98 +88,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { let mut file_map: HashMap = HashMap::new(); // How we communicate between the worker and receiver threads - let last_completion_request = Arc::new(Mutex::new(None)); + let last_worker_request = Arc::new(RwLock::new(None)); // Thread local variables - let thread_last_completion_request = last_completion_request.clone(); + let thread_last_worker_request = last_worker_request.clone(); let thread_connection = connection.clone(); thread::spawn(move || { - loop { - // I think we need this drop, not 100% sure though - let mut completion_request = thread_last_completion_request.lock(); - let params = std::mem::take(&mut *completion_request); - drop(completion_request); - if let Some(CompletionRequest { - id, - params, - mut rope, - }) = params - { - let filter_text = rope - .get_line(params.text_document_position.position.line as usize) - .expect("Error getting line with ropey") - .to_string(); - - // Convert rope to correct prompt for llm - let cursor_index = rope - .line_to_char(params.text_document_position.position.line as usize) - + params.text_document_position.position.character as usize; - - // We will want to have some kind of infill support we add - // rope.insert(cursor_index, "<|fim_hole|>"); - // rope.insert(0, "<|fim_start|>"); - // rope.insert(rope.len_chars(), "<|fim_end|>"); - // let prompt = rope.to_string(); - - let prompt = rope - .get_slice((0..cursor_index)) - .expect("Error getting rope slice") - .to_string(); - - eprintln!("\n\n****{prompt}****\n\n"); - - let insert_text = Python::with_gil(|py| -> Result { - let transform: Py = PY_MODULE - .as_ref() - .map_err(anyhow::Error::msg)? - .getattr(py, "transform")?; - - let out: String = transform.call1(py, (prompt,))?.extract(py)?; - Ok(out) - }) - .expect("Error during transform"); - - eprintln!("\n{insert_text}\n"); - - // Create and return the completion - let completion_text_edit = TextEdit::new( - Range::new( - Position::new( - params.text_document_position.position.line, - params.text_document_position.position.character, - ), - Position::new( - params.text_document_position.position.line, - params.text_document_position.position.character, - ), - ), - insert_text.clone(), - ); - let item = CompletionItem { - label: format!("ai - {insert_text}"), - filter_text: Some(filter_text), - text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)), - kind: Some(CompletionItemKind::TEXT), - ..Default::default() - }; - let completion_list = CompletionList { - is_incomplete: false, - items: vec![item], - }; - let result = Some(CompletionResponse::List(completion_list)); - let result = serde_json::to_value(&result).unwrap(); - let resp = Response { - id, - result: Some(result), - error: None, - }; - thread_connection - .sender - .send(Message::Response(resp)) - .expect("Error sending response"); - } - thread::sleep(std::time::Duration::from_millis(5)); - } + worker::run(thread_last_worker_request, thread_connection); }); for msg in &connection.receiver { @@ -200,21 +103,44 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { if connection.handle_shutdown(&req)? { return Ok(()); } - match cast::(req) { - Ok((id, params)) => { - // Get rope - let rope = file_map - .get(params.text_document_position.text_document.uri.as_str()) - .context("Error file not found")? - .clone(); - // Update the last CompletionRequest - let mut lcr = last_completion_request.lock(); - *lcr = Some(CompletionRequest::new(id, params, rope)); - continue; + eprintln!( + "NEW REQUEST: \n{}\n", + serde_json::to_string_pretty(&req).unwrap() + ); + + match req.method.as_str() { + "textDocument/completion" => match cast::(req) { + Ok((id, params)) => { + // Get rope + let rope = file_map + .get(params.text_document_position.text_document.uri.as_str()) + .context("Error file not found")? + .clone(); + // Update the last CompletionRequest + let mut lcr = last_worker_request.write(); + let completion_request = CompletionRequest::new(id, params, rope); + *lcr = Some(WorkerRequest::Completion(completion_request)); + } + Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"), + Err(ExtractError::MethodMismatch(_req)) => (), + }, + "textDocument/generate" => match cast::(req) { + Ok((id, params)) => { + // Get rope + let rope = file_map + .get(params.text_document_position.text_document.uri.as_str()) + .context("Error file not found")? + .clone(); + // Update the last CompletionRequest + let mut lcr = last_worker_request.write(); + let completion_request = GenerateRequest::new(id, params, rope); + *lcr = Some(WorkerRequest::Generate(completion_request)); + } + Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"), + Err(ExtractError::MethodMismatch(_req)) => (), } - Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"), - Err(ExtractError::MethodMismatch(req)) => req, - }; + _ => eprintln!("lsp-ai currently only supports textDocument/completion and textDocument/generate") + } } Message::Notification(not) => { if notification_is::(¬) { diff --git a/src/worker/completion.rs b/src/worker/completion.rs new file mode 100644 index 0000000..58de2f7 --- /dev/null +++ b/src/worker/completion.rs @@ -0,0 +1,61 @@ +use lsp_server::ResponseError; +use pyo3::prelude::*; + +use super::CompletionRequest; +use crate::PY_MODULE; + +pub struct DoCompletionResponse { + pub insert_text: String, + pub filter_text: String, +} + +pub fn do_completion(request: &CompletionRequest) -> Result { + let filter_text = request + .rope + .get_line(request.params.text_document_position.position.line as usize) + .ok_or(ResponseError { + code: -32603, // Maybe we want a different error code here? + message: "Error getting line in requested document".to_string(), + data: None, + })? + .to_string(); + + // Convert rope to correct prompt for llm + let cursor_index = request + .rope + .line_to_char(request.params.text_document_position.position.line as usize) + + request.params.text_document_position.position.character as usize; + + // We will want to have some kind of infill support we add + // rope.insert(cursor_index, "<|fim_hole|>"); + // rope.insert(0, "<|fim_start|>"); + // rope.insert(rope.len_chars(), "<|fim_end|>"); + // let prompt = rope.to_string(); + + let prompt = request + .rope + .get_slice(0..cursor_index) + .expect("Error getting rope slice") + .to_string(); + + eprintln!("\n\n****{prompt}****\n\n"); + + Python::with_gil(|py| -> anyhow::Result { + let transform: Py = PY_MODULE + .as_ref() + .map_err(anyhow::Error::msg)? + .getattr(py, "transform")?; + + let out: String = transform.call1(py, (prompt,))?.extract(py)?; + Ok(out) + }) + .map(|insert_text| DoCompletionResponse { + insert_text, + filter_text, + }) + .map_err(|e| ResponseError { + code: -32603, + message: e.to_string(), + data: None, + }) +} diff --git a/src/worker/generate.rs b/src/worker/generate.rs new file mode 100644 index 0000000..89bce44 --- /dev/null +++ b/src/worker/generate.rs @@ -0,0 +1,47 @@ +use lsp_server::ResponseError; +use pyo3::prelude::*; + +use super::GenerateRequest; +use crate::PY_MODULE; + +pub struct DoGenerateResponse { + pub generated_text: String, +} + +pub fn do_generate(request: &GenerateRequest) -> Result { + // Convert rope to correct prompt for llm + let cursor_index = request + .rope + .line_to_char(request.params.text_document_position.position.line as usize) + + request.params.text_document_position.position.character as usize; + + // We will want to have some kind of infill support we add + // rope.insert(cursor_index, "<|fim_hole|>"); + // rope.insert(0, "<|fim_start|>"); + // rope.insert(rope.len_chars(), "<|fim_end|>"); + // let prompt = rope.to_string(); + + let prompt = request + .rope + .get_slice(0..cursor_index) + .expect("Error getting rope slice") + .to_string(); + + eprintln!("\n\n****{prompt}****\n\n"); + + Python::with_gil(|py| -> anyhow::Result { + let transform: Py = PY_MODULE + .as_ref() + .map_err(anyhow::Error::msg)? + .getattr(py, "transform")?; + + let out: String = transform.call1(py, (prompt,))?.extract(py)?; + Ok(out) + }) + .map(|generated_text| DoGenerateResponse { generated_text }) + .map_err(|e| ResponseError { + code: -32603, + message: e.to_string(), + data: None, + }) +} diff --git a/src/worker/mod.rs b/src/worker/mod.rs new file mode 100644 index 0000000..4d0d52b --- /dev/null +++ b/src/worker/mod.rs @@ -0,0 +1,125 @@ +use lsp_server::{Connection, Message, RequestId, Response}; +use lsp_types::{ + CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse, + Position, Range, TextEdit, +}; +use parking_lot::RwLock; +use ropey::Rope; +use std::{sync::Arc, thread}; + +mod completion; +mod generate; + +use crate::custom_requests::generate::{GenerateParams, GenerateResult}; +use completion::do_completion; +use generate::do_generate; + +#[derive(Clone)] +pub struct CompletionRequest { + id: RequestId, + params: CompletionParams, + rope: Rope, +} + +impl CompletionRequest { + pub fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self { + Self { id, params, rope } + } +} + +#[derive(Clone)] +pub struct GenerateRequest { + id: RequestId, + params: GenerateParams, + rope: Rope, +} + +impl GenerateRequest { + pub fn new(id: RequestId, params: GenerateParams, rope: Rope) -> Self { + Self { id, params, rope } + } +} + +#[derive(Clone)] +pub enum WorkerRequest { + Completion(CompletionRequest), + Generate(GenerateRequest), +} + +pub fn run(last_worker_request: Arc>>, connection: Arc) { + loop { + let option_worker_request: Option = { + let completion_request = last_worker_request.read(); + (*completion_request).clone() + }; + if let Some(request) = option_worker_request { + let response = match request { + WorkerRequest::Completion(request) => match do_completion(&request) { + Ok(response) => { + let completion_text_edit = TextEdit::new( + Range::new( + Position::new( + request.params.text_document_position.position.line, + request.params.text_document_position.position.character, + ), + Position::new( + request.params.text_document_position.position.line, + request.params.text_document_position.position.character, + ), + ), + response.insert_text.clone(), + ); + let item = CompletionItem { + label: format!("ai - {}", response.insert_text), + filter_text: Some(response.filter_text), + text_edit: Some(lsp_types::CompletionTextEdit::Edit( + completion_text_edit, + )), + kind: Some(CompletionItemKind::TEXT), + ..Default::default() + }; + let completion_list = CompletionList { + is_incomplete: false, + items: vec![item], + }; + let result = Some(CompletionResponse::List(completion_list)); + let result = serde_json::to_value(&result).unwrap(); + Response { + id: request.id, + result: Some(result), + error: None, + } + } + Err(e) => Response { + id: request.id, + result: None, + error: Some(e), + }, + }, + WorkerRequest::Generate(request) => match do_generate(&request) { + Ok(result) => { + let result = GenerateResult { + generated_text: result.generated_text, + }; + let result = serde_json::to_value(&result).unwrap(); + Response { + id: request.id, + result: Some(result), + error: None, + } + } + Err(e) => Response { + id: request.id, + result: None, + error: Some(e), + }, + }, + }; + connection + .sender + .send(Message::Response(response)) + .expect("Error sending response"); + } + thread::sleep(std::time::Duration::from_millis(5)); + } +}