Merge branch 'main' into right-arrow-json

This commit is contained in:
Kacper Madej
2025-01-10 19:25:56 +07:00
11 changed files with 215 additions and 88 deletions

38
.github/workflows/java.yml vendored Normal file
View File

@@ -0,0 +1,38 @@
name: Java Tests
on:
push:
branches:
- main
tags:
- v*
pull_request:
branches:
- main
env:
working-directory: bindings/java
jobs:
test:
runs-on: ubuntu-latest
defaults:
run:
working-directory: ${{ env.working-directory }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Rust(stable)
uses: dtolnay/rust-toolchain@stable
- name: Set up JDK
uses: actions/setup-java@v3
with:
distribution: 'temurin'
java-version: '11'
- name: Run Java tests
run: make test

View File

@@ -1,6 +1,6 @@
.PHONY: lib
.PHONY: test build_test
run_test: build_test
test: build_test
./gradlew test
build_test:

View File

@@ -50,18 +50,34 @@ pub extern "system" fn Java_org_github_tursodatabase_core_LimboDB__1open_1utf8<'
Box::into_raw(Box::new(db)) as jlong
}
#[no_mangle]
pub extern "system" fn Java_org_github_tursodatabase_core_LimboDB_throwJavaException<'local>(
mut env: JNIEnv<'local>,
obj: JObject<'local>,
error_code: jint,
) {
set_err_msg_and_throw_exception(
&mut env,
obj,
error_code,
"throw java exception".to_string(),
);
}
fn set_err_msg_and_throw_exception<'local>(
env: &mut JNIEnv<'local>,
obj: JObject<'local>,
err_code: i32,
err_msg: String,
) {
let error_message_pointer = Box::into_raw(Box::new(err_msg)) as i64;
let error_message_bytes = env
.byte_array_from_slice(err_msg.as_bytes())
.expect("Failed to convert to byte array");
match env.call_method(
obj,
"newSQLException",
"(IJ)Lorg/github/tursodatabase/exceptions/LimboException;",
&[err_code.into(), error_message_pointer.into()],
"throwLimboException",
"(I[B)V",
&[err_code.into(), (&error_message_bytes).into()],
) {
Ok(_) => {
// do nothing because above method will always return Err
@@ -71,16 +87,3 @@ fn set_err_msg_and_throw_exception<'local>(
}
}
}
#[no_mangle]
pub unsafe extern "system" fn Java_org_github_tursodatabase_core_LimboDB_getErrorMessageUtf8<
'local,
>(
env: JNIEnv<'local>,
_obj: JObject<'local>,
error_message_ptr: jlong,
) -> JByteArray<'local> {
let error_message = Box::from_raw(error_message_ptr as *mut String);
let error_message_bytes = error_message.as_bytes();
env.byte_array_from_slice(error_message_bytes).unwrap()
}

View File

@@ -9,7 +9,7 @@ import java.lang.annotation.Target;
/**
* Annotation to mark methods that are called by native functions.
*/
@Retention(RetentionPolicy.RUNTIME)
@Retention(RetentionPolicy.SOURCE)
@Target(ElementType.METHOD)
public @interface NativeInvocation {
}

View File

@@ -0,0 +1,14 @@
package org.github.tursodatabase;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Annotation to mark methods that use larger visibility for testing purposes.
*/
@Retention(RetentionPolicy.SOURCE)
@Target(ElementType.METHOD)
public @interface VisibleForTesting {
}

View File

@@ -172,35 +172,4 @@ public abstract class AbstractDB {
// TODO: add implementation
throw new SQLFeatureNotSupportedException();
}
/**
* Throws SQL Exception with error code.
*
* @param errorCode Error code to be passed.
* @throws SQLException Formatted SQLException with error code
*/
@NativeInvocation
private LimboException newSQLException(int errorCode, long errorMessagePointer) throws SQLException {
throw newSQLException(errorCode, getErrorMessage(errorMessagePointer));
}
/**
* Throws formatted SQLException with error code and message.
*
* @param errorCode Error code to be passed.
* @param errorMessage throw newSQLException(errorCode);Error message to be passed.
* @return Formatted SQLException with error code and message.
*/
public static LimboException newSQLException(int errorCode, String errorMessage) {
LimboErrorCode code = LimboErrorCode.getErrorCode(errorCode);
String msg;
if (code == LimboErrorCode.UNKNOWN_ERROR) {
msg = String.format("%s:%s (%s)", code, errorCode, errorMessage);
} else {
msg = String.format("%s (%s)", code, errorMessage);
}
return new LimboException(msg, code);
}
protected abstract String getErrorMessage(long errorMessagePointer);
}

View File

@@ -2,6 +2,9 @@ package org.github.tursodatabase.core;
import org.github.tursodatabase.LimboErrorCode;
import org.github.tursodatabase.NativeInvocation;
import org.github.tursodatabase.VisibleForTesting;
import org.github.tursodatabase.exceptions.LimboException;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
@@ -30,8 +33,7 @@ public final class LimboDB extends AbstractDB {
// url example: "jdbc:sqlite:{fileName}
/**
*
* @param url e.g. "jdbc:sqlite:fileName
* @param url e.g. "jdbc:sqlite:fileName
* @param fileName e.g. path to file
*/
public static LimboDB create(String url, String fileName) throws SQLException {
@@ -83,7 +85,7 @@ public final class LimboDB extends AbstractDB {
@Override
protected void _open(String fileName, int openFlags) throws SQLException {
if (isOpen) {
throw newSQLException(LimboErrorCode.UNKNOWN_ERROR.code, "Already opened");
throwLimboException(LimboErrorCode.UNKNOWN_ERROR.code, "Already opened");
}
dbPtr = _open_utf8(stringToUtf8ByteArray(fileName), openFlags);
isOpen = true;
@@ -103,12 +105,38 @@ public final class LimboDB extends AbstractDB {
@Override
public synchronized native int step(long stmt);
@Override
protected String getErrorMessage(long errorMessagePointer) {
return utf8ByteBufferToString(getErrorMessageUtf8(errorMessagePointer));
@VisibleForTesting
native void throwJavaException(int errorCode) throws SQLException;
/**
* Throws formatted SQLException with error code and message.
*
* @param errorCode Error code.
* @param errorMessageBytes Error message.
*/
@NativeInvocation
private void throwLimboException(int errorCode, byte[] errorMessageBytes) throws SQLException {
String errorMessage = utf8ByteBufferToString(errorMessageBytes);
throwLimboException(errorCode, errorMessage);
}
private native byte[] getErrorMessageUtf8(long errorMessagePointer);
/**
* Throws formatted SQLException with error code and message.
*
* @param errorCode Error code.
* @param errorMessage Error message.
*/
public void throwLimboException(int errorCode, String errorMessage) throws SQLException {
LimboErrorCode code = LimboErrorCode.getErrorCode(errorCode);
String msg;
if (code == LimboErrorCode.UNKNOWN_ERROR) {
msg = String.format("%s:%s (%s)", code, errorCode, errorMessage);
} else {
msg = String.format("%s (%s)", code, errorMessage);
}
throw new LimboException(msg, code);
}
private static String utf8ByteBufferToString(byte[] buffer) {
if (buffer == null) {

View File

@@ -1,10 +1,13 @@
package org.github.tursodatabase.core;
import org.github.tursodatabase.LimboErrorCode;
import org.github.tursodatabase.TestUtils;
import org.github.tursodatabase.exceptions.LimboException;
import org.junit.jupiter.api.Test;
import java.sql.SQLException;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
public class LimboDBTest {
@@ -26,4 +29,20 @@ public class LimboDBTest {
assertThatThrownBy(() -> db.open(0)).isInstanceOf(SQLException.class);
}
@Test
void throwJavaException_should_throw_appropriate_java_exception() throws Exception {
String dbPath = TestUtils.createTempFile();
LimboDB db = LimboDB.create("jdbc:sqlite:" + dbPath, dbPath);
db.load();
final int limboExceptionCode = LimboErrorCode.ETC.code;
try {
db.throwJavaException(limboExceptionCode);
} catch (Exception e) {
assertThat(e).isInstanceOf(LimboException.class);
LimboException limboException = (LimboException) e;
assertThat(limboException.getResultCode().code).isEqualTo(limboExceptionCode);
}
}
}

View File

@@ -24,7 +24,7 @@ pub fn optimize_plan(plan: &mut Plan) -> Result<()> {
*/
fn optimize_select_plan(plan: &mut SelectPlan) -> Result<()> {
optimize_subqueries(&mut plan.source)?;
rewrite_exprs(&mut plan.source, &mut plan.where_clause)?;
rewrite_exprs_select(plan)?;
if let ConstantConditionEliminationResult::ImpossibleCondition =
eliminate_constants(&mut plan.source, &mut plan.where_clause)?
{
@@ -55,7 +55,7 @@ fn optimize_select_plan(plan: &mut SelectPlan) -> Result<()> {
}
fn optimize_delete_plan(plan: &mut DeletePlan) -> Result<()> {
rewrite_exprs(&mut plan.source, &mut plan.where_clause)?;
rewrite_exprs_delete(plan)?;
if let ConstantConditionEliminationResult::ImpossibleCondition =
eliminate_constants(&mut plan.source, &mut plan.where_clause)?
{
@@ -603,16 +603,45 @@ fn push_scan_direction(operator: &mut SourceOperator, direction: &Direction) {
}
}
fn rewrite_exprs(
operator: &mut SourceOperator,
where_clauses: &mut Option<Vec<ast::Expr>>,
) -> Result<()> {
if let Some(predicates) = where_clauses {
for expr in predicates.iter_mut() {
fn rewrite_exprs_select(plan: &mut SelectPlan) -> Result<()> {
rewrite_source_operator_exprs(&mut plan.source)?;
for rc in plan.result_columns.iter_mut() {
rewrite_expr(&mut rc.expr)?;
}
for agg in plan.aggregates.iter_mut() {
rewrite_expr(&mut agg.original_expr)?;
}
if let Some(predicates) = &mut plan.where_clause {
for expr in predicates {
rewrite_expr(expr)?;
}
}
if let Some(group_by) = &mut plan.group_by {
for expr in group_by.exprs.iter_mut() {
rewrite_expr(expr)?;
}
}
if let Some(order_by) = &mut plan.order_by {
for (expr, _) in order_by.iter_mut() {
rewrite_expr(expr)?;
}
}
Ok(())
}
fn rewrite_exprs_delete(plan: &mut DeletePlan) -> Result<()> {
rewrite_source_operator_exprs(&mut plan.source)?;
if let Some(predicates) = &mut plan.where_clause {
for expr in predicates {
rewrite_expr(expr)?;
}
}
Ok(())
}
fn rewrite_source_operator_exprs(operator: &mut SourceOperator) -> Result<()> {
match operator {
SourceOperator::Join {
left,
@@ -620,35 +649,37 @@ fn rewrite_exprs(
predicates,
..
} => {
rewrite_exprs(left, where_clauses)?;
rewrite_exprs(right, where_clauses)?;
rewrite_source_operator_exprs(left)?;
rewrite_source_operator_exprs(right)?;
if let Some(predicates) = predicates {
for expr in predicates.iter_mut() {
rewrite_expr(expr)?;
}
}
}
SourceOperator::Scan {
predicates: Some(preds),
..
} => {
for expr in preds.iter_mut() {
rewrite_expr(expr)?;
}
}
SourceOperator::Search {
predicates: Some(preds),
..
} => {
for expr in preds.iter_mut() {
rewrite_expr(expr)?;
}
}
_ => (),
}
Ok(())
Ok(())
}
SourceOperator::Scan { predicates, .. } | SourceOperator::Search { predicates, .. } => {
if let Some(predicates) = predicates {
for expr in predicates.iter_mut() {
rewrite_expr(expr)?;
}
}
Ok(())
}
SourceOperator::Subquery { predicates, .. } => {
if let Some(predicates) = predicates {
for expr in predicates.iter_mut() {
rewrite_expr(expr)?;
}
}
Ok(())
}
SourceOperator::Nothing { .. } => Ok(()),
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]

View File

@@ -544,6 +544,14 @@ fn parse_join(
pub fn parse_limit(limit: Limit) -> Option<usize> {
if let Expr::Literal(ast::Literal::Numeric(n)) = limit.expr {
n.parse().ok()
} else if let Expr::Id(id) = limit.expr {
if id.0.eq_ignore_ascii_case("true") {
Some(1)
} else if id.0.eq_ignore_ascii_case("false") {
Some(0)
} else {
None
}
} else {
None
}

View File

@@ -11,6 +11,14 @@ do_execsql_test select-const-2 {
SELECT 2
} {2}
do_execsql_test select-true {
SELECT true
} {1}
do_execsql_test select-false {
SELECT false
} {0}
do_execsql_test select-text-escape-1 {
SELECT '''a'
} {'a}
@@ -31,6 +39,15 @@ do_execsql_test select-limit-0 {
SELECT id FROM users LIMIT 0;
} {}
# ORDER BY id here because sqlite uses age_idx here and we (yet) don't so force it to evaluate in ID order
do_execsql_test select-limit-true {
SELECT id FROM users ORDER BY id LIMIT true;
} {1}
do_execsql_test select-limit-false {
SELECT id FROM users ORDER BY id LIMIT false;
} {}
do_execsql_test realify {
select price from products limit 1;
} {79.0}