mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2025-12-18 15:04:29 +01:00
Reorganized
This commit is contained in:
@@ -12,6 +12,14 @@
|
||||
"engines": {
|
||||
"vscode": "^1.75.0"
|
||||
},
|
||||
"contributes": {
|
||||
"commands": [
|
||||
{
|
||||
"command": "lsp-ai.generate",
|
||||
"title": "LSP AI Generate"
|
||||
}
|
||||
]
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.11.0",
|
||||
"typescript": "^5.3.3"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { workspace, ExtensionContext } from 'vscode';
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import {
|
||||
LanguageClient,
|
||||
LanguageClientOptions,
|
||||
@@ -9,9 +8,7 @@ import {
|
||||
|
||||
let client: LanguageClient;
|
||||
|
||||
export function activate(_context: ExtensionContext) {
|
||||
console.log("\n\nIN THE ACTIVATE FUNCTION\n\n");
|
||||
|
||||
export function activate(context: vscode.ExtensionContext) {
|
||||
// Configure the server options
|
||||
let serverOptions: ServerOptions = {
|
||||
command: "lsp-ai",
|
||||
@@ -20,11 +17,7 @@ export function activate(_context: ExtensionContext) {
|
||||
|
||||
// Options to control the language client
|
||||
let clientOptions: LanguageClientOptions = {
|
||||
documentSelector: [{ scheme: 'file', language: 'python' }],
|
||||
synchronize: {
|
||||
// Notify the server about file changes to '.clientrc files contained in the workspace
|
||||
fileEvents: workspace.createFileSystemWatcher('**/.clientrc')
|
||||
}
|
||||
documentSelector: [{ pattern: "**" }],
|
||||
};
|
||||
|
||||
// Create the language client and start the client
|
||||
@@ -35,10 +28,34 @@ export function activate(_context: ExtensionContext) {
|
||||
clientOptions
|
||||
);
|
||||
|
||||
console.log("\n\nSTARTING THE CLIENT\n\n");
|
||||
|
||||
// Start the client. This will also launch the server
|
||||
client.start();
|
||||
|
||||
client.onRequest("textDocument/completion", (params) => {
|
||||
console.log("HERE WE GO");
|
||||
console.log(params);
|
||||
});
|
||||
|
||||
// Register functions
|
||||
const command = 'lsp-ai.generate';
|
||||
const commandHandler = () => {
|
||||
const editor = vscode.window.activeTextEditor;
|
||||
console.log("SENDING REQUEST FOR GENERATE");
|
||||
console.log(editor);
|
||||
let params = {
|
||||
textDocument: {
|
||||
uri: editor.document.uri.toString(),
|
||||
},
|
||||
position: editor.selection.active
|
||||
};
|
||||
console.log(params);
|
||||
client.sendRequest("textDocument/generate", params).then(result => {
|
||||
console.log(result);
|
||||
}).catch(error => {
|
||||
console.error(error);
|
||||
});
|
||||
};
|
||||
context.subscriptions.push(vscode.commands.registerCommand(command, commandHandler));
|
||||
}
|
||||
|
||||
export function deactivate(): Thenable<void> | undefined {
|
||||
|
||||
23
src/custom_requests/generate.rs
Normal file
23
src/custom_requests/generate.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use lsp_types::TextDocumentPositionParams;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub enum Generate {}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateParams {
|
||||
// This field was "mixed-in" from TextDocumentPositionParams
|
||||
#[serde(flatten)]
|
||||
pub text_document_position: TextDocumentPositionParams,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
|
||||
pub struct GenerateResult {
|
||||
pub generated_text: String,
|
||||
}
|
||||
|
||||
impl lsp_types::request::Request for Generate {
|
||||
type Params = GenerateParams;
|
||||
type Result = GenerateResult;
|
||||
const METHOD: &'static str = "textDocument/generate";
|
||||
}
|
||||
1
src/custom_requests/mod.rs
Normal file
1
src/custom_requests/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod generate;
|
||||
182
src/main.rs
182
src/main.rs
@@ -1,21 +1,21 @@
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use core::panic;
|
||||
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId, Response};
|
||||
use anyhow::{Context, Result};
|
||||
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId};
|
||||
use lsp_types::{
|
||||
request::Completion, CompletionItem, CompletionItemKind, CompletionList, CompletionOptions,
|
||||
CompletionParams, CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
|
||||
Position, Range, RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit,
|
||||
request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
|
||||
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::RwLock;
|
||||
use pyo3::prelude::*;
|
||||
use ropey::Rope;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::{collections::HashMap, sync::Arc, thread};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use pyo3::prelude::*;
|
||||
mod custom_requests;
|
||||
mod worker;
|
||||
|
||||
use custom_requests::generate::Generate;
|
||||
use worker::{CompletionRequest, GenerateRequest, WorkerRequest};
|
||||
|
||||
pub static PY_MODULE: Lazy<Result<Py<PyAny>>> = Lazy::new(|| {
|
||||
pyo3::Python::with_gil(|py| -> Result<Py<PyAny>> {
|
||||
@@ -67,23 +67,11 @@ fn main() -> Result<()> {
|
||||
#[derive(Deserialize)]
|
||||
struct Params {}
|
||||
|
||||
struct CompletionRequest {
|
||||
id: RequestId,
|
||||
params: CompletionParams,
|
||||
rope: Rope,
|
||||
}
|
||||
|
||||
impl CompletionRequest {
|
||||
fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self {
|
||||
Self { id, params, rope }
|
||||
}
|
||||
}
|
||||
|
||||
// This main loop is tricky
|
||||
// We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get
|
||||
// Completion requests may take a few seconds given the model configuration and hardware allowed, and we only want to process the latest completion request
|
||||
fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||
let params: Params = serde_json::from_value(params)?;
|
||||
let _params: Params = serde_json::from_value(params)?;
|
||||
|
||||
// Set the model
|
||||
Python::with_gil(|py| -> Result<()> {
|
||||
@@ -100,98 +88,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||
let mut file_map: HashMap<String, Rope> = HashMap::new();
|
||||
|
||||
// How we communicate between the worker and receiver threads
|
||||
let last_completion_request = Arc::new(Mutex::new(None));
|
||||
let last_worker_request = Arc::new(RwLock::new(None));
|
||||
|
||||
// Thread local variables
|
||||
let thread_last_completion_request = last_completion_request.clone();
|
||||
let thread_last_worker_request = last_worker_request.clone();
|
||||
let thread_connection = connection.clone();
|
||||
thread::spawn(move || {
|
||||
loop {
|
||||
// I think we need this drop, not 100% sure though
|
||||
let mut completion_request = thread_last_completion_request.lock();
|
||||
let params = std::mem::take(&mut *completion_request);
|
||||
drop(completion_request);
|
||||
if let Some(CompletionRequest {
|
||||
id,
|
||||
params,
|
||||
mut rope,
|
||||
}) = params
|
||||
{
|
||||
let filter_text = rope
|
||||
.get_line(params.text_document_position.position.line as usize)
|
||||
.expect("Error getting line with ropey")
|
||||
.to_string();
|
||||
|
||||
// Convert rope to correct prompt for llm
|
||||
let cursor_index = rope
|
||||
.line_to_char(params.text_document_position.position.line as usize)
|
||||
+ params.text_document_position.position.character as usize;
|
||||
|
||||
// We will want to have some kind of infill support we add
|
||||
// rope.insert(cursor_index, "<|fim_hole|>");
|
||||
// rope.insert(0, "<|fim_start|>");
|
||||
// rope.insert(rope.len_chars(), "<|fim_end|>");
|
||||
// let prompt = rope.to_string();
|
||||
|
||||
let prompt = rope
|
||||
.get_slice((0..cursor_index))
|
||||
.expect("Error getting rope slice")
|
||||
.to_string();
|
||||
|
||||
eprintln!("\n\n****{prompt}****\n\n");
|
||||
|
||||
let insert_text = Python::with_gil(|py| -> Result<String> {
|
||||
let transform: Py<PyAny> = PY_MODULE
|
||||
.as_ref()
|
||||
.map_err(anyhow::Error::msg)?
|
||||
.getattr(py, "transform")?;
|
||||
|
||||
let out: String = transform.call1(py, (prompt,))?.extract(py)?;
|
||||
Ok(out)
|
||||
})
|
||||
.expect("Error during transform");
|
||||
|
||||
eprintln!("\n{insert_text}\n");
|
||||
|
||||
// Create and return the completion
|
||||
let completion_text_edit = TextEdit::new(
|
||||
Range::new(
|
||||
Position::new(
|
||||
params.text_document_position.position.line,
|
||||
params.text_document_position.position.character,
|
||||
),
|
||||
Position::new(
|
||||
params.text_document_position.position.line,
|
||||
params.text_document_position.position.character,
|
||||
),
|
||||
),
|
||||
insert_text.clone(),
|
||||
);
|
||||
let item = CompletionItem {
|
||||
label: format!("ai - {insert_text}"),
|
||||
filter_text: Some(filter_text),
|
||||
text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
|
||||
kind: Some(CompletionItemKind::TEXT),
|
||||
..Default::default()
|
||||
};
|
||||
let completion_list = CompletionList {
|
||||
is_incomplete: false,
|
||||
items: vec![item],
|
||||
};
|
||||
let result = Some(CompletionResponse::List(completion_list));
|
||||
let result = serde_json::to_value(&result).unwrap();
|
||||
let resp = Response {
|
||||
id,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
};
|
||||
thread_connection
|
||||
.sender
|
||||
.send(Message::Response(resp))
|
||||
.expect("Error sending response");
|
||||
}
|
||||
thread::sleep(std::time::Duration::from_millis(5));
|
||||
}
|
||||
worker::run(thread_last_worker_request, thread_connection);
|
||||
});
|
||||
|
||||
for msg in &connection.receiver {
|
||||
@@ -200,21 +103,44 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||
if connection.handle_shutdown(&req)? {
|
||||
return Ok(());
|
||||
}
|
||||
match cast::<Completion>(req) {
|
||||
Ok((id, params)) => {
|
||||
// Get rope
|
||||
let rope = file_map
|
||||
.get(params.text_document_position.text_document.uri.as_str())
|
||||
.context("Error file not found")?
|
||||
.clone();
|
||||
// Update the last CompletionRequest
|
||||
let mut lcr = last_completion_request.lock();
|
||||
*lcr = Some(CompletionRequest::new(id, params, rope));
|
||||
continue;
|
||||
eprintln!(
|
||||
"NEW REQUEST: \n{}\n",
|
||||
serde_json::to_string_pretty(&req).unwrap()
|
||||
);
|
||||
|
||||
match req.method.as_str() {
|
||||
"textDocument/completion" => match cast::<Completion>(req) {
|
||||
Ok((id, params)) => {
|
||||
// Get rope
|
||||
let rope = file_map
|
||||
.get(params.text_document_position.text_document.uri.as_str())
|
||||
.context("Error file not found")?
|
||||
.clone();
|
||||
// Update the last CompletionRequest
|
||||
let mut lcr = last_worker_request.write();
|
||||
let completion_request = CompletionRequest::new(id, params, rope);
|
||||
*lcr = Some(WorkerRequest::Completion(completion_request));
|
||||
}
|
||||
Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
|
||||
Err(ExtractError::MethodMismatch(_req)) => (),
|
||||
},
|
||||
"textDocument/generate" => match cast::<Generate>(req) {
|
||||
Ok((id, params)) => {
|
||||
// Get rope
|
||||
let rope = file_map
|
||||
.get(params.text_document_position.text_document.uri.as_str())
|
||||
.context("Error file not found")?
|
||||
.clone();
|
||||
// Update the last CompletionRequest
|
||||
let mut lcr = last_worker_request.write();
|
||||
let completion_request = GenerateRequest::new(id, params, rope);
|
||||
*lcr = Some(WorkerRequest::Generate(completion_request));
|
||||
}
|
||||
Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
|
||||
Err(ExtractError::MethodMismatch(_req)) => (),
|
||||
}
|
||||
Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
|
||||
Err(ExtractError::MethodMismatch(req)) => req,
|
||||
};
|
||||
_ => eprintln!("lsp-ai currently only supports textDocument/completion and textDocument/generate")
|
||||
}
|
||||
}
|
||||
Message::Notification(not) => {
|
||||
if notification_is::<lsp_types::notification::DidOpenTextDocument>(¬) {
|
||||
|
||||
61
src/worker/completion.rs
Normal file
61
src/worker/completion.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use lsp_server::ResponseError;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use super::CompletionRequest;
|
||||
use crate::PY_MODULE;
|
||||
|
||||
pub struct DoCompletionResponse {
|
||||
pub insert_text: String,
|
||||
pub filter_text: String,
|
||||
}
|
||||
|
||||
pub fn do_completion(request: &CompletionRequest) -> Result<DoCompletionResponse, ResponseError> {
|
||||
let filter_text = request
|
||||
.rope
|
||||
.get_line(request.params.text_document_position.position.line as usize)
|
||||
.ok_or(ResponseError {
|
||||
code: -32603, // Maybe we want a different error code here?
|
||||
message: "Error getting line in requested document".to_string(),
|
||||
data: None,
|
||||
})?
|
||||
.to_string();
|
||||
|
||||
// Convert rope to correct prompt for llm
|
||||
let cursor_index = request
|
||||
.rope
|
||||
.line_to_char(request.params.text_document_position.position.line as usize)
|
||||
+ request.params.text_document_position.position.character as usize;
|
||||
|
||||
// We will want to have some kind of infill support we add
|
||||
// rope.insert(cursor_index, "<|fim_hole|>");
|
||||
// rope.insert(0, "<|fim_start|>");
|
||||
// rope.insert(rope.len_chars(), "<|fim_end|>");
|
||||
// let prompt = rope.to_string();
|
||||
|
||||
let prompt = request
|
||||
.rope
|
||||
.get_slice(0..cursor_index)
|
||||
.expect("Error getting rope slice")
|
||||
.to_string();
|
||||
|
||||
eprintln!("\n\n****{prompt}****\n\n");
|
||||
|
||||
Python::with_gil(|py| -> anyhow::Result<String> {
|
||||
let transform: Py<PyAny> = PY_MODULE
|
||||
.as_ref()
|
||||
.map_err(anyhow::Error::msg)?
|
||||
.getattr(py, "transform")?;
|
||||
|
||||
let out: String = transform.call1(py, (prompt,))?.extract(py)?;
|
||||
Ok(out)
|
||||
})
|
||||
.map(|insert_text| DoCompletionResponse {
|
||||
insert_text,
|
||||
filter_text,
|
||||
})
|
||||
.map_err(|e| ResponseError {
|
||||
code: -32603,
|
||||
message: e.to_string(),
|
||||
data: None,
|
||||
})
|
||||
}
|
||||
47
src/worker/generate.rs
Normal file
47
src/worker/generate.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use lsp_server::ResponseError;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use super::GenerateRequest;
|
||||
use crate::PY_MODULE;
|
||||
|
||||
pub struct DoGenerateResponse {
|
||||
pub generated_text: String,
|
||||
}
|
||||
|
||||
pub fn do_generate(request: &GenerateRequest) -> Result<DoGenerateResponse, ResponseError> {
|
||||
// Convert rope to correct prompt for llm
|
||||
let cursor_index = request
|
||||
.rope
|
||||
.line_to_char(request.params.text_document_position.position.line as usize)
|
||||
+ request.params.text_document_position.position.character as usize;
|
||||
|
||||
// We will want to have some kind of infill support we add
|
||||
// rope.insert(cursor_index, "<|fim_hole|>");
|
||||
// rope.insert(0, "<|fim_start|>");
|
||||
// rope.insert(rope.len_chars(), "<|fim_end|>");
|
||||
// let prompt = rope.to_string();
|
||||
|
||||
let prompt = request
|
||||
.rope
|
||||
.get_slice(0..cursor_index)
|
||||
.expect("Error getting rope slice")
|
||||
.to_string();
|
||||
|
||||
eprintln!("\n\n****{prompt}****\n\n");
|
||||
|
||||
Python::with_gil(|py| -> anyhow::Result<String> {
|
||||
let transform: Py<PyAny> = PY_MODULE
|
||||
.as_ref()
|
||||
.map_err(anyhow::Error::msg)?
|
||||
.getattr(py, "transform")?;
|
||||
|
||||
let out: String = transform.call1(py, (prompt,))?.extract(py)?;
|
||||
Ok(out)
|
||||
})
|
||||
.map(|generated_text| DoGenerateResponse { generated_text })
|
||||
.map_err(|e| ResponseError {
|
||||
code: -32603,
|
||||
message: e.to_string(),
|
||||
data: None,
|
||||
})
|
||||
}
|
||||
125
src/worker/mod.rs
Normal file
125
src/worker/mod.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use lsp_server::{Connection, Message, RequestId, Response};
|
||||
use lsp_types::{
|
||||
CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
|
||||
Position, Range, TextEdit,
|
||||
};
|
||||
use parking_lot::RwLock;
|
||||
use ropey::Rope;
|
||||
use std::{sync::Arc, thread};
|
||||
|
||||
mod completion;
|
||||
mod generate;
|
||||
|
||||
use crate::custom_requests::generate::{GenerateParams, GenerateResult};
|
||||
use completion::do_completion;
|
||||
use generate::do_generate;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionRequest {
|
||||
id: RequestId,
|
||||
params: CompletionParams,
|
||||
rope: Rope,
|
||||
}
|
||||
|
||||
impl CompletionRequest {
|
||||
pub fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self {
|
||||
Self { id, params, rope }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct GenerateRequest {
|
||||
id: RequestId,
|
||||
params: GenerateParams,
|
||||
rope: Rope,
|
||||
}
|
||||
|
||||
impl GenerateRequest {
|
||||
pub fn new(id: RequestId, params: GenerateParams, rope: Rope) -> Self {
|
||||
Self { id, params, rope }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum WorkerRequest {
|
||||
Completion(CompletionRequest),
|
||||
Generate(GenerateRequest),
|
||||
}
|
||||
|
||||
pub fn run(last_worker_request: Arc<RwLock<Option<WorkerRequest>>>, connection: Arc<Connection>) {
|
||||
loop {
|
||||
let option_worker_request: Option<WorkerRequest> = {
|
||||
let completion_request = last_worker_request.read();
|
||||
(*completion_request).clone()
|
||||
};
|
||||
if let Some(request) = option_worker_request {
|
||||
let response = match request {
|
||||
WorkerRequest::Completion(request) => match do_completion(&request) {
|
||||
Ok(response) => {
|
||||
let completion_text_edit = TextEdit::new(
|
||||
Range::new(
|
||||
Position::new(
|
||||
request.params.text_document_position.position.line,
|
||||
request.params.text_document_position.position.character,
|
||||
),
|
||||
Position::new(
|
||||
request.params.text_document_position.position.line,
|
||||
request.params.text_document_position.position.character,
|
||||
),
|
||||
),
|
||||
response.insert_text.clone(),
|
||||
);
|
||||
let item = CompletionItem {
|
||||
label: format!("ai - {}", response.insert_text),
|
||||
filter_text: Some(response.filter_text),
|
||||
text_edit: Some(lsp_types::CompletionTextEdit::Edit(
|
||||
completion_text_edit,
|
||||
)),
|
||||
kind: Some(CompletionItemKind::TEXT),
|
||||
..Default::default()
|
||||
};
|
||||
let completion_list = CompletionList {
|
||||
is_incomplete: false,
|
||||
items: vec![item],
|
||||
};
|
||||
let result = Some(CompletionResponse::List(completion_list));
|
||||
let result = serde_json::to_value(&result).unwrap();
|
||||
Response {
|
||||
id: request.id,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
Err(e) => Response {
|
||||
id: request.id,
|
||||
result: None,
|
||||
error: Some(e),
|
||||
},
|
||||
},
|
||||
WorkerRequest::Generate(request) => match do_generate(&request) {
|
||||
Ok(result) => {
|
||||
let result = GenerateResult {
|
||||
generated_text: result.generated_text,
|
||||
};
|
||||
let result = serde_json::to_value(&result).unwrap();
|
||||
Response {
|
||||
id: request.id,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
Err(e) => Response {
|
||||
id: request.id,
|
||||
result: None,
|
||||
error: Some(e),
|
||||
},
|
||||
},
|
||||
};
|
||||
connection
|
||||
.sender
|
||||
.send(Message::Response(response))
|
||||
.expect("Error sending response");
|
||||
}
|
||||
thread::sleep(std::time::Duration::from_millis(5));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user