risingwave_jni_core/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![feature(error_generic_member_access)]
16#![feature(once_cell_try)]
17#![feature(type_alias_impl_trait)]
18#![feature(try_blocks)]
19#![feature(used_with_arg)]
20
21pub mod jvm_runtime;
22mod macros;
23mod tracing_slf4j;
24
25use std::backtrace::Backtrace;
26use std::marker::PhantomData;
27use std::ops::{Deref, DerefMut};
28use std::slice::from_raw_parts;
29use std::sync::{LazyLock, OnceLock};
30
31use anyhow::anyhow;
32use bytes::Bytes;
33use cfg_or_panic::cfg_or_panic;
34use chrono::{DateTime, Datelike, Timelike};
35use futures::TryStreamExt;
36use futures::stream::BoxStream;
37use jni::JNIEnv;
38use jni::objects::{
39    AutoElements, GlobalRef, JByteArray, JClass, JMethodID, JObject, JStaticMethodID, JString,
40    JValueOwned, ReleaseMode,
41};
42use jni::signature::ReturnType;
43use jni::sys::{
44    JNI_FALSE, JNI_TRUE, jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue,
45};
46pub use paste::paste;
47use prost::{DecodeError, Message};
48use risingwave_common::array::{ArrayError, StreamChunk};
49use risingwave_common::hash::VirtualNode;
50use risingwave_common::row::{OwnedRow, Row};
51use risingwave_common::test_prelude::StreamChunkTestExt;
52use risingwave_common::types::{Decimal, ScalarRefImpl};
53use risingwave_common::util::panic::rw_catch_unwind;
54use risingwave_pb::connector_service::{
55    GetEventStreamResponse, SinkCoordinatorStreamRequest, SinkCoordinatorStreamResponse,
56    SinkWriterStreamRequest, SinkWriterStreamResponse,
57};
58use risingwave_pb::data::Op;
59use thiserror::Error;
60use thiserror_ext::AsReport;
61use tokio::runtime::Runtime;
62use tokio::sync::mpsc::{Receiver, Sender};
63use tracing_slf4j::*;
64
/// Enable JVM and Java libraries.
///
/// This macro forces this crate to be linked, which registers the JVM builder.
///
/// It expands to an anonymous `use risingwave_jni_core as _;`, which references the crate
/// without bringing any name into scope. Invoke it in a crate that needs the JNI bindings
/// available at runtime (e.g. at its root module).
#[macro_export]
macro_rules! enable {
    () => {
        use risingwave_jni_core as _;
    };
}
74
75pub static JAVA_BINDING_ASYNC_RUNTIME: LazyLock<Runtime> =
76    LazyLock::new(|| tokio::runtime::Runtime::new().unwrap());
77
/// Errors raised inside JNI entry points.
///
/// Every variant wraps its source error with `#[from]`, so `?` converts it automatically.
/// Each variant also carries a [`Backtrace`] captured when that conversion happens.
/// `execute_and_catch` reports these errors to Java as exceptions.
#[derive(Error, Debug)]
pub enum BindingError {
    /// Error from the `jni` crate.
    ///
    /// `jni::errors::Error::JavaException` means a Java exception is already pending in the
    /// JVM.
    #[error("JniError {error}")]
    Jni {
        #[from]
        error: jni::errors::Error,
        backtrace: Backtrace,
    },

    /// Generic `anyhow` error, e.g. from advancing a Hummock iterator.
    #[error("StorageError {error}")]
    Storage {
        #[from]
        error: anyhow::Error,
        backtrace: Backtrace,
    },

    /// Protobuf decoding failure, e.g. for a malformed stream chunk payload.
    #[error("DecodeError {error}")]
    Decode {
        #[from]
        error: DecodeError,
        backtrace: Backtrace,
    },

    /// Array-level error; presumably raised while rebuilding a `StreamChunk` from protobuf.
    #[error("StreamChunkArrayError {error}")]
    StreamChunkArray {
        #[from]
        error: ArrayError,
        backtrace: Backtrace,
    },
}

/// Result alias used throughout the binding layer.
type Result<T> = std::result::Result<T, BindingError>;
110
111pub fn to_guarded_slice<'array, 'env>(
112    array: &'array JByteArray<'env>,
113    env: &'array mut JNIEnv<'env>,
114) -> Result<SliceGuard<'env, 'array>> {
115    unsafe {
116        let array = env.get_array_elements(array, ReleaseMode::NoCopyBack)?;
117        let slice = from_raw_parts(array.as_ptr() as *mut u8, array.len());
118
119        Ok(SliceGuard {
120            _array: array,
121            slice,
122        })
123    }
124}
125
126/// Wrapper around `&[u8]` derived from `jbyteArray` to prevent it from being auto-released.
127pub struct SliceGuard<'env, 'array> {
128    _array: AutoElements<'env, 'env, 'array, jbyte>,
129    slice: &'array [u8],
130}
131
132impl Deref for SliceGuard<'_, '_> {
133    type Target = [u8];
134
135    fn deref(&self) -> &Self::Target {
136        self.slice
137    }
138}
139
140#[repr(transparent)]
141pub struct Pointer<'a, T> {
142    pointer: jlong,
143    _phantom: PhantomData<&'a T>,
144}
145
146impl<T> Default for Pointer<'_, T> {
147    fn default() -> Self {
148        Self {
149            pointer: 0,
150            _phantom: Default::default(),
151        }
152    }
153}
154
155impl<T> From<T> for Pointer<'static, T> {
156    fn from(value: T) -> Self {
157        Pointer {
158            pointer: Box::into_raw(Box::new(value)) as jlong,
159            _phantom: PhantomData,
160        }
161    }
162}
163
164impl<'a, T> Pointer<'a, T> {
165    fn as_ref(&self) -> &'a T {
166        assert!(self.pointer != 0);
167        unsafe { &*(self.pointer as *const T) }
168    }
169
170    fn as_mut(&mut self) -> &'a mut T {
171        assert!(self.pointer != 0);
172        unsafe { &mut *(self.pointer as *mut T) }
173    }
174}
175
176/// A pointer that owns the object it points to.
177///
178/// Note that dropping an `OwnedPointer` does not release the object.
179/// Instead, you should call [`OwnedPointer::release`] manually.
180pub type OwnedPointer<T> = Pointer<'static, T>;
181
182impl<T> OwnedPointer<T> {
183    /// Consume `self` and return the pointer value. Used for passing to JNI.
184    pub fn into_pointer(self) -> jlong {
185        self.pointer
186    }
187
188    /// Release the object behind the pointer.
189    fn release(self) {
190        tracing::debug!(
191            type_name = std::any::type_name::<T>(),
192            address = %format_args!("{:x}", self.pointer),
193            "release jni OwnedPointer"
194        );
195        assert!(self.pointer != 0);
196        unsafe { drop(Box::from_raw(self.pointer as *mut T)) }
197    }
198}
199
200/// In most Jni interfaces, the first parameter is `JNIEnv`, and the second parameter is `JClass`.
201/// This struct simply encapsulates the two common parameters into a single struct for simplicity.
202#[repr(C)]
203pub struct EnvParam<'a> {
204    env: JNIEnv<'a>,
205    class: JClass<'a>,
206}
207
208impl<'a> Deref for EnvParam<'a> {
209    type Target = JNIEnv<'a>;
210
211    fn deref(&self) -> &Self::Target {
212        &self.env
213    }
214}
215
216impl DerefMut for EnvParam<'_> {
217    fn deref_mut(&mut self) -> &mut Self::Target {
218        &mut self.env
219    }
220}
221
222impl<'a> EnvParam<'a> {
223    pub fn get_class(&self) -> &JClass<'a> {
224        &self.class
225    }
226}
227
228pub fn execute_and_catch<'env, F, Ret>(mut env: EnvParam<'env>, inner: F) -> Ret
229where
230    F: FnOnce(&mut EnvParam<'env>) -> Result<Ret>,
231    Ret: Default + 'env,
232{
233    match rw_catch_unwind(std::panic::AssertUnwindSafe(|| inner(&mut env))) {
234        Ok(Ok(ret)) => ret,
235        Ok(Err(e)) => {
236            match e {
237                BindingError::Jni {
238                    error: jni::errors::Error::JavaException,
239                    backtrace,
240                } => {
241                    tracing::error!("get JavaException thrown from: {:?}", backtrace);
242                    // the exception is already thrown. No need to throw again
243                }
244                _ => {
245                    env.throw(format!("get error while processing: {:?}", e.as_report()))
246                        .expect("should be able to throw");
247                }
248            }
249            Ret::default()
250        }
251        Err(e) => {
252            env.throw(format!("panic while processing: {:?}", e))
253                .expect("should be able to throw");
254            Ret::default()
255        }
256    }
257}
258
/// Per-iterator cache of Java classes and method IDs used to build Java objects for column
/// values.
///
/// Each entry is filled lazily on first use via `OnceLock::get_or_try_init`. Classes are held
/// as `GlobalRef`s so they stay valid across JNI calls. Per the JNI spec, a loaded class that
/// is still referenced cannot be unloaded, so the cached method IDs also stay valid.
#[derive(Default)]
struct JavaClassMethodCache {
    // `java.math.BigDecimal` and its `BigDecimal(String)` constructor.
    big_decimal_ctor: OnceLock<(GlobalRef, JMethodID)>,

    // `java.time.LocalDateTime` and its static `of(...)` factory.
    timestamp_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    // `java.time.OffsetDateTime` and its static `ofInstant(Instant, ZoneId)` factory.
    timestamptz_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    // `java.time.LocalDate` and its static `ofEpochDay(long)` factory.
    date_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    // `java.time.LocalTime` and its static `of(...)` factory.
    time_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    // `java.time.Instant` and its static `ofEpochSecond(long, long)` factory.
    instant_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    // The `java.time.ZoneOffset.UTC` constant.
    utc: OnceLock<GlobalRef>,
}
270
/// Hosts the opaque (`impl Trait`) alias for the stream-chunk row iterator together with the
/// single function that defines its hidden type (nightly `type_alias_impl_trait`).
mod opaque_type {
    use super::*;
    // TODO: may only return a RowRef
    /// Iterator over the rows of a borrowed `StreamChunk`. It yields each row's protobuf `Op`
    /// together with an owned copy of the row.
    pub type StreamChunkRowIterator<'a> = impl Iterator<Item = (Op, OwnedRow)> + 'a;

    impl<'a> JavaBindingIteratorInner<'a> {
        /// Creates an iterator over `chunk`'s rows.
        ///
        /// `chunk` stays borrowed for `'a`. Each row is copied into an `OwnedRow` as it is
        /// yielded.
        #[define_opaque(StreamChunkRowIterator)]
        pub(super) fn from_chunk(chunk: &'a StreamChunk) -> JavaBindingIteratorInner<'a> {
            JavaBindingIteratorInner::StreamChunk(
                chunk
                    .rows()
                    .map(|(op, row)| (op.to_protobuf(), row.to_owned_row())),
            )
        }
    }
}
pub use opaque_type::StreamChunkRowIterator;
/// Boxed async stream of `(key, row)` pairs; the key becomes `RowExtra::Key` in `iteratorNext`.
pub type HummockJavaBindingIterator = BoxStream<'static, anyhow::Result<(Bytes, OwnedRow)>>;
/// The row source behind a `JavaBindingIterator`.
pub enum JavaBindingIteratorInner<'a> {
    /// Async stream of keyed rows, polled by blocking on `JAVA_BINDING_ASYNC_RUNTIME`.
    Hummock(HummockJavaBindingIterator),
    /// Rows of an in-memory `StreamChunk`; each row carries its `Op`.
    StreamChunk(StreamChunkRowIterator<'a>),
}
293
294enum RowExtra {
295    Op(Op),
296    Key(Bytes),
297}
298
299impl RowExtra {
300    fn as_op(&self) -> Op {
301        match self {
302            RowExtra::Op(op) => *op,
303            RowExtra::Key(_) => unreachable!("should be op"),
304        }
305    }
306
307    fn as_key(&self) -> &Bytes {
308        match self {
309            RowExtra::Key(key) => key,
310            RowExtra::Op(_) => unreachable!("should be key"),
311        }
312    }
313}
314
315struct RowCursor {
316    row: OwnedRow,
317    extra: RowExtra,
318}
319
320pub struct JavaBindingIterator<'a> {
321    inner: JavaBindingIteratorInner<'a>,
322    cursor: Option<RowCursor>,
323    class_cache: JavaClassMethodCache,
324}
325
326impl JavaBindingIterator<'static> {
327    pub fn new_hummock_iter(iter: HummockJavaBindingIterator) -> Self {
328        Self {
329            inner: JavaBindingIteratorInner::Hummock(iter),
330            cursor: None,
331            class_cache: Default::default(),
332        }
333    }
334}
335
336impl Deref for JavaBindingIterator<'_> {
337    type Target = OwnedRow;
338
339    fn deref(&self) -> &Self::Target {
340        &self
341            .cursor
342            .as_ref()
343            .expect("should exist when call row methods")
344            .row
345    }
346}
347
348#[unsafe(no_mangle)]
349extern "system" fn Java_com_risingwave_java_binding_Binding_defaultVnodeCount(
350    _env: EnvParam<'_>,
351) -> jint {
352    VirtualNode::COUNT_FOR_COMPAT as jint
353}
354
355#[cfg_or_panic(not(madsim))]
356#[unsafe(no_mangle)]
357extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNewStreamChunk<'a>(
358    env: EnvParam<'a>,
359    chunk: Pointer<'a, StreamChunk>,
360) -> Pointer<'static, JavaBindingIterator<'a>> {
361    execute_and_catch(env, move |_env| {
362        let iter = JavaBindingIterator {
363            inner: JavaBindingIteratorInner::from_chunk(chunk.as_ref()),
364            cursor: None,
365            class_cache: Default::default(),
366        };
367        Ok(iter.into())
368    })
369}
370
371#[cfg_or_panic(not(madsim))]
372#[unsafe(no_mangle)]
373extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNext<'a>(
374    env: EnvParam<'a>,
375    mut pointer: Pointer<'a, JavaBindingIterator<'a>>,
376) -> jboolean {
377    execute_and_catch(env, move |_env| {
378        let iter = pointer.as_mut();
379        match &mut iter.inner {
380            JavaBindingIteratorInner::Hummock(hummock_iter) => {
381                match JAVA_BINDING_ASYNC_RUNTIME.block_on(hummock_iter.try_next())? {
382                    None => {
383                        iter.cursor = None;
384                        Ok(JNI_FALSE)
385                    }
386                    Some((key, row)) => {
387                        iter.cursor = Some(RowCursor {
388                            row,
389                            extra: RowExtra::Key(key),
390                        });
391                        Ok(JNI_TRUE)
392                    }
393                }
394            }
395            JavaBindingIteratorInner::StreamChunk(stream_chunk_iter) => {
396                match stream_chunk_iter.next() {
397                    None => {
398                        iter.cursor = None;
399                        Ok(JNI_FALSE)
400                    }
401                    Some((op, row)) => {
402                        iter.cursor = Some(RowCursor {
403                            row,
404                            extra: RowExtra::Op(op),
405                        });
406                        Ok(JNI_TRUE)
407                    }
408                }
409            }
410        }
411    })
412}
413
414#[unsafe(no_mangle)]
415extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorClose<'a>(
416    _env: EnvParam<'a>,
417    pointer: OwnedPointer<JavaBindingIterator<'a>>,
418) {
419    pointer.release()
420}
421
422#[unsafe(no_mangle)]
423extern "system" fn Java_com_risingwave_java_binding_Binding_newStreamChunkFromPayload<'a>(
424    env: EnvParam<'a>,
425    stream_chunk_payload: JByteArray<'a>,
426) -> Pointer<'static, StreamChunk> {
427    execute_and_catch(env, move |env| {
428        let prost_stream_chumk =
429            Message::decode(to_guarded_slice(&stream_chunk_payload, env)?.deref())?;
430        Ok(StreamChunk::from_protobuf(&prost_stream_chumk)?.into())
431    })
432}
433
434#[unsafe(no_mangle)]
435extern "system" fn Java_com_risingwave_java_binding_Binding_newStreamChunkFromPretty<'a>(
436    env: EnvParam<'a>,
437    str: JString<'a>,
438) -> Pointer<'static, StreamChunk> {
439    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
440        Ok(StreamChunk::from_pretty(env.get_string(&str)?.to_str().unwrap()).into())
441    })
442}
443
444#[unsafe(no_mangle)]
445extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkClose(
446    _env: EnvParam<'_>,
447    chunk: OwnedPointer<StreamChunk>,
448) {
449    chunk.release()
450}
451
452#[unsafe(no_mangle)]
453extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetKey<'a>(
454    env: EnvParam<'a>,
455    pointer: Pointer<'a, JavaBindingIterator<'a>>,
456) -> JByteArray<'a> {
457    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
458        Ok(env.byte_array_from_slice(
459            pointer
460                .as_ref()
461                .cursor
462                .as_ref()
463                .expect("should exists when call get key")
464                .extra
465                .as_key()
466                .as_ref(),
467        )?)
468    })
469}
470
471#[unsafe(no_mangle)]
472extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetOp<'a>(
473    env: EnvParam<'a>,
474    pointer: Pointer<'a, JavaBindingIterator<'a>>,
475) -> jint {
476    execute_and_catch(env, move |_env| {
477        Ok(pointer
478            .as_ref()
479            .cursor
480            .as_ref()
481            .expect("should exist when call get op")
482            .extra
483            .as_op() as jint)
484    })
485}
486
487#[unsafe(no_mangle)]
488extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorIsNull<'a>(
489    env: EnvParam<'a>,
490    pointer: Pointer<'a, JavaBindingIterator<'a>>,
491    idx: jint,
492) -> jboolean {
493    execute_and_catch(env, move |_env| {
494        Ok(pointer.as_ref().datum_at(idx as usize).is_none() as jboolean)
495    })
496}
497
498#[unsafe(no_mangle)]
499extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt16Value<'a>(
500    env: EnvParam<'a>,
501    pointer: Pointer<'a, JavaBindingIterator<'a>>,
502    idx: jint,
503) -> jshort {
504    execute_and_catch(env, move |_env| {
505        Ok(pointer
506            .as_ref()
507            .datum_at(idx as usize)
508            .unwrap()
509            .into_int16())
510    })
511}
512
513#[unsafe(no_mangle)]
514extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt32Value<'a>(
515    env: EnvParam<'a>,
516    pointer: Pointer<'a, JavaBindingIterator<'a>>,
517    idx: jint,
518) -> jint {
519    execute_and_catch(env, move |_env| {
520        Ok(pointer
521            .as_ref()
522            .datum_at(idx as usize)
523            .unwrap()
524            .into_int32())
525    })
526}
527
528#[unsafe(no_mangle)]
529extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt64Value<'a>(
530    env: EnvParam<'a>,
531    pointer: Pointer<'a, JavaBindingIterator<'a>>,
532    idx: jint,
533) -> jlong {
534    execute_and_catch(env, move |_env| {
535        Ok(pointer
536            .as_ref()
537            .datum_at(idx as usize)
538            .unwrap()
539            .into_int64())
540    })
541}
542
543#[unsafe(no_mangle)]
544extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetFloatValue<'a>(
545    env: EnvParam<'a>,
546    pointer: Pointer<'a, JavaBindingIterator<'a>>,
547    idx: jint,
548) -> jfloat {
549    execute_and_catch(env, move |_env| {
550        Ok(pointer
551            .as_ref()
552            .datum_at(idx as usize)
553            .unwrap()
554            .into_float32()
555            .into())
556    })
557}
558
559#[unsafe(no_mangle)]
560extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDoubleValue<'a>(
561    env: EnvParam<'a>,
562    pointer: Pointer<'a, JavaBindingIterator<'a>>,
563    idx: jint,
564) -> jdouble {
565    execute_and_catch(env, move |_env| {
566        Ok(pointer
567            .as_ref()
568            .datum_at(idx as usize)
569            .unwrap()
570            .into_float64()
571            .into())
572    })
573}
574
575#[unsafe(no_mangle)]
576extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetBooleanValue<'a>(
577    env: EnvParam<'a>,
578    pointer: Pointer<'a, JavaBindingIterator<'a>>,
579    idx: jint,
580) -> jboolean {
581    execute_and_catch(env, move |_env| {
582        Ok(pointer.as_ref().datum_at(idx as usize).unwrap().into_bool() as jboolean)
583    })
584}
585
586#[unsafe(no_mangle)]
587extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetStringValue<'a>(
588    env: EnvParam<'a>,
589    pointer: Pointer<'a, JavaBindingIterator<'a>>,
590    idx: jint,
591) -> JString<'a> {
592    execute_and_catch(env, move |env: &mut EnvParam<'a>| {
593        Ok(env.new_string(pointer.as_ref().datum_at(idx as usize).unwrap().into_utf8())?)
594    })
595}
596
597#[unsafe(no_mangle)]
598extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetIntervalValue<'a>(
599    env: EnvParam<'a>,
600    pointer: Pointer<'a, JavaBindingIterator<'a>>,
601    idx: jint,
602) -> JString<'a> {
603    execute_and_catch(env, move |env: &mut EnvParam<'a>| {
604        let interval = pointer
605            .as_ref()
606            .datum_at(idx as usize)
607            .unwrap()
608            .into_interval()
609            .as_iso_8601();
610        Ok(env.new_string(interval)?)
611    })
612}
613
614#[unsafe(no_mangle)]
615extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetJsonbValue<'a>(
616    env: EnvParam<'a>,
617    pointer: Pointer<'a, JavaBindingIterator<'a>>,
618    idx: jint,
619) -> JString<'a> {
620    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
621        let jsonb = pointer
622            .as_ref()
623            .datum_at(idx as usize)
624            .unwrap()
625            .into_jsonb()
626            .to_string();
627        Ok(env.new_string(jsonb)?)
628    })
629}
630
/// Returns column `idx` of the current row, a timestamp, as a `java.time.LocalDateTime`.
///
/// The object is built with `LocalDateTime.of(year, month, dayOfMonth, hour, minute, second,
/// nanoOfSecond)` from the timestamp's calendar fields. The class and method ID are cached in
/// the iterator's `class_cache` after first use. A NULL value panics on `unwrap`, which
/// `execute_and_catch` reports to Java as an exception.
#[unsafe(no_mangle)]
extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimestampValue<'a>(
    env: EnvParam<'a>,
    pointer: Pointer<'a, JavaBindingIterator<'a>>,
    idx: jint,
) -> JObject<'a> {
    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
        let value = pointer
            .as_ref()
            .datum_at(idx as usize)
            .unwrap()
            .into_timestamp();

        let sig = gen_jni_sig!(java.time.LocalDateTime of(int year, int month, int dayOfMonth, int hour, int minute, int second, int nanoOfSecond));

        // Look up `LocalDateTime` and its static `of` factory once per iterator.
        let (timestamp_class_ref, constructor) = pointer
            .as_ref()
            .class_cache
            .timestamp_ctor
            .get_or_try_init(|| {
                let cls = env.find_class(gen_class_name!(java.time.LocalDateTime))?;
                let init_method = env.get_static_method_id(&cls, "of", sig)?;
                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
            })?;
        // SAFETY (review): the method ID was resolved on this class with `sig`, and the seven
        // `jvalue`s below are the `int` arguments that signature declares, in order.
        unsafe {
            let JValueOwned::Object(timestamp_obj) = env.call_static_method_unchecked(
                <&JClass<'_>>::from(timestamp_class_ref.as_obj()),
                *constructor,
                ReturnType::Object,
                &[
                    jvalue { i: value.0.year() },
                    jvalue {
                        i: value.0.month() as i32,
                    },
                    jvalue {
                        i: value.0.day() as i32,
                    },
                    jvalue {
                        i: value.0.hour() as i32,
                    },
                    jvalue {
                        i: value.0.minute() as i32,
                    },
                    jvalue {
                        i: value.0.second() as i32,
                    },
                    jvalue {
                        i: value.0.nanosecond() as i32,
                    },
                ],
            )?
            else {
                // Should not happen with `ReturnType::Object`; reported as a signature mismatch.
                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
                    name: "of".to_owned(),
                    sig: sig.into(),
                }));
            };
            Ok(timestamp_obj)
        }
    })
}
692
/// Returns column `idx` of the current row, a timestamptz, as a `java.time.OffsetDateTime`
/// in UTC.
///
/// The conversion first builds an `Instant` with
/// `Instant.ofEpochSecond(value.timestamp(), value.timestamp_subsec_nanos())`. It then calls
/// `OffsetDateTime.ofInstant(instant, ZoneOffset.UTC)`. Classes, method IDs and the `UTC`
/// constant are cached in the iterator's `class_cache` after first use. A NULL value panics on
/// `unwrap`, which `execute_and_catch` reports to Java as an exception.
#[unsafe(no_mangle)]
extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimestamptzValue<'a>(
    env: EnvParam<'a>,
    pointer: Pointer<'a, JavaBindingIterator<'a>>,
    idx: jint,
) -> JObject<'a> {
    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
        let value = pointer
            .as_ref()
            .datum_at(idx as usize)
            .unwrap()
            .into_timestamptz();

        let instant_sig =
            gen_jni_sig!(java.time.Instant ofEpochSecond(long epochSecond, long nanoAdjustment));

        // Step 1: `Instant.ofEpochSecond(epochSecond, nanoAdjustment)`.
        let (instant_class_ref, instant_constructor) = pointer
            .as_ref()
            .class_cache
            .instant_ctor
            .get_or_try_init(|| {
                let cls = env.find_class(gen_class_name!(java.time.Instant))?;
                let init_method = env.get_static_method_id(&cls, "ofEpochSecond", instant_sig)?;
                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
            })?;
        // SAFETY (review): the method ID was resolved on this class with `instant_sig`, and the
        // two `jvalue`s are the `long` arguments that signature declares.
        let instant_obj = unsafe {
            let JValueOwned::Object(instant_obj) = env.call_static_method_unchecked(
                <&JClass<'_>>::from(instant_class_ref.as_obj()),
                *instant_constructor,
                ReturnType::Object,
                &[
                    jvalue {
                        j: value.timestamp(),
                    },
                    jvalue {
                        j: value.timestamp_subsec_nanos() as i64,
                    },
                ],
            )?
            else {
                // Should not happen with `ReturnType::Object`; reported as a signature mismatch.
                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
                    name: "ofEpochSecond".to_owned(),
                    sig: instant_sig.into(),
                }));
            };
            instant_obj
        };

        // Step 2: fetch (and cache) the `ZoneOffset.UTC` constant.
        let utc_ref = pointer.as_ref().class_cache.utc.get_or_try_init(|| {
            let cls = env.find_class(gen_class_name!(java.time.ZoneOffset))?;
            let utc = env
                .get_static_field(&cls, "UTC", gen_jni_type_sig!(java.time.ZoneOffset))?
                .l()?;
            env.new_global_ref(utc)
        })?;

        let sig = gen_jni_sig!(java.time.OffsetDateTime ofInstant(java.time.Instant instant, java.time.ZoneId zone));

        // Step 3: `OffsetDateTime.ofInstant(instant, UTC)`.
        let (timestamptz_class_ref, constructor) = pointer
            .as_ref()
            .class_cache
            .timestamptz_ctor
            .get_or_try_init(|| {
                let cls = env.find_class(gen_class_name!(java.time.OffsetDateTime))?;
                let init_method = env.get_static_method_id(&cls, "ofInstant", sig)?;
                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
            })?;
        // SAFETY (review): the method ID was resolved on this class with `sig`; the arguments
        // are an `Instant` and a `ZoneOffset` (a `ZoneId`) object reference, in that order.
        unsafe {
            let JValueOwned::Object(timestamptz_obj) = env.call_static_method_unchecked(
                <&JClass<'_>>::from(timestamptz_class_ref.as_obj()),
                *constructor,
                ReturnType::Object,
                &[
                    jvalue {
                        l: instant_obj.as_raw(),
                    },
                    jvalue {
                        l: utc_ref.as_obj().as_raw(),
                    },
                ],
            )?
            else {
                // Should not happen with `ReturnType::Object`; reported as a signature mismatch.
                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
                    name: "ofInstant".to_owned(),
                    sig: sig.into(),
                }));
            };
            Ok(timestamptz_obj)
        }
    })
}
784
785#[unsafe(no_mangle)]
786extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDecimalValue<'a>(
787    env: EnvParam<'a>,
788    pointer: Pointer<'a, JavaBindingIterator<'a>>,
789    idx: jint,
790) -> JObject<'a> {
791    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
792        let decimal_value = pointer
793            .as_ref()
794            .datum_at(idx as usize)
795            .unwrap()
796            .into_decimal();
797
798        match decimal_value {
799            Decimal::NaN | Decimal::NegativeInf | Decimal::PositiveInf => {
800                return Ok(JObject::null());
801            }
802            Decimal::Normalized(_) => {}
803        };
804
805        let value = decimal_value.to_string();
806        let string_value = env.new_string(value)?;
807        let (decimal_class_ref, constructor) = pointer
808            .as_ref()
809            .class_cache
810            .big_decimal_ctor
811            .get_or_try_init(|| {
812                let cls = env.find_class("java/math/BigDecimal")?;
813                let init_method = env.get_method_id(&cls, "<init>", "(Ljava/lang/String;)V")?;
814                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
815            })?;
816        unsafe {
817            let decimal_class = <&JClass<'_>>::from(decimal_class_ref.as_obj());
818            let date_obj = env.new_object_unchecked(
819                decimal_class,
820                *constructor,
821                &[jvalue {
822                    l: string_value.into_raw(),
823                }],
824            )?;
825            Ok(date_obj)
826        }
827    })
828}
829
/// Returns column `idx` of the current row, a date, as a `java.time.LocalDate`.
///
/// The date is converted to a day count relative to 1970-01-01 and passed to
/// `LocalDate.ofEpochDay(long)`. The class and method ID are cached in the iterator's
/// `class_cache`. A NULL value panics on `unwrap`, which `execute_and_catch` reports to Java
/// as an exception.
#[unsafe(no_mangle)]
extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDateValue<'a>(
    env: EnvParam<'a>,
    pointer: Pointer<'a, JavaBindingIterator<'a>>,
    idx: jint,
) -> JObject<'a> {
    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
        let value = pointer.as_ref().datum_at(idx as usize).unwrap().into_date();
        // Days since the Unix epoch (negative for earlier dates), as `ofEpochDay` expects.
        let epoch_days = (value.0 - DateTime::UNIX_EPOCH.date_naive()).num_days();

        let sig = gen_jni_sig!(java.time.LocalDate ofEpochDay(long));

        let (date_class_ref, constructor) =
            pointer.as_ref().class_cache.date_ctor.get_or_try_init(|| {
                let cls = env.find_class(gen_class_name!(java.time.LocalDate))?;
                let init_method = env.get_static_method_id(&cls, "ofEpochDay", sig)?;
                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
            })?;
        // SAFETY (review): the method ID was resolved on this class with `sig`, and the single
        // `jvalue` is the `long` argument that signature declares.
        unsafe {
            let JValueOwned::Object(date_obj) = env.call_static_method_unchecked(
                <&JClass<'_>>::from(date_class_ref.as_obj()),
                *constructor,
                ReturnType::Object,
                &[jvalue { j: epoch_days }],
            )?
            else {
                // Should not happen with `ReturnType::Object`; reported as a signature mismatch.
                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
                    name: "ofEpochDay".to_owned(),
                    sig: sig.into(),
                }));
            };
            Ok(date_obj)
        }
    })
}
865
/// Returns column `idx` of the current row, a time of day, as a `java.time.LocalTime`.
///
/// The object is built with `LocalTime.of(hour, minute, second, nanoOfSecond)`. The class and
/// method ID are cached in the iterator's `class_cache`. A NULL value panics on `unwrap`, which
/// `execute_and_catch` reports to Java as an exception.
#[unsafe(no_mangle)]
extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimeValue<'a>(
    env: EnvParam<'a>,
    pointer: Pointer<'a, JavaBindingIterator<'a>>,
    idx: jint,
) -> JObject<'a> {
    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
        let value = pointer.as_ref().datum_at(idx as usize).unwrap().into_time();

        let sig = gen_jni_sig!(java.time.LocalTime of(int hour, int minute, int second, int nanoOfSecond));

        let (time_class_ref, constructor) =
            pointer.as_ref().class_cache.time_ctor.get_or_try_init(|| {
                let cls = env.find_class(gen_class_name!(java.time.LocalTime))?;
                let init_method = env.get_static_method_id(&cls, "of", sig)?;
                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
            })?;
        // SAFETY (review): the method ID was resolved on this class with `sig`, and the four
        // `jvalue`s are the `int` arguments that signature declares, in order.
        unsafe {
            let JValueOwned::Object(time_obj) = env.call_static_method_unchecked(
                <&JClass<'_>>::from(time_class_ref.as_obj()),
                *constructor,
                ReturnType::Object,
                &[
                    jvalue {
                        i: value.0.hour() as i32,
                    },
                    jvalue {
                        i: value.0.minute() as i32,
                    },
                    jvalue {
                        i: value.0.second() as i32,
                    },
                    jvalue {
                        i: value.0.nanosecond() as i32,
                    },
                ],
            )?
            else {
                // Should not happen with `ReturnType::Object`; reported as a signature mismatch.
                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
                    name: "of".to_owned(),
                    sig: sig.into(),
                }));
            };
            Ok(time_obj)
        }
    })
}
913
914#[unsafe(no_mangle)]
915extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetByteaValue<'a>(
916    env: EnvParam<'a>,
917    pointer: Pointer<'a, JavaBindingIterator<'a>>,
918    idx: jint,
919) -> JByteArray<'a> {
920    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
921        let bytes = pointer
922            .as_ref()
923            .datum_at(idx as usize)
924            .unwrap()
925            .into_bytea();
926        Ok(env.byte_array_from_slice(bytes)?)
927    })
928}
929
930#[unsafe(no_mangle)]
931extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetArrayValue<'a>(
932    env: EnvParam<'a>,
933    pointer: Pointer<'a, JavaBindingIterator<'a>>,
934    idx: jint,
935    class: JClass<'a>,
936) -> JObject<'a> {
937    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
938        let elems = pointer
939            .as_ref()
940            .datum_at(idx as usize)
941            .unwrap()
942            .into_list()
943            .iter();
944
945        // convert the Rust elements to a Java object array (Object[])
946        let jarray = env.new_object_array(elems.len() as jsize, &class, JObject::null())?;
947
948        for (i, ele) in elems.enumerate() {
949            let index = i as jsize;
950            match ele {
951                None => env.set_object_array_element(&jarray, i as jsize, JObject::null())?,
952                Some(val) => match val {
953                    ScalarRefImpl::Int16(v) => {
954                        let o = call_static_method!(
955                            env,
956                            {Short},
957                            {Short	valueOf(short s)},
958                            v
959                        )?;
960                        env.set_object_array_element(&jarray, index, &o)?;
961                    }
962                    ScalarRefImpl::Int32(v) => {
963                        let o = call_static_method!(
964                            env,
965                            {Integer},
966                            {Integer	valueOf(int i)},
967                            v
968                        )?;
969                        env.set_object_array_element(&jarray, index, &o)?;
970                    }
971                    ScalarRefImpl::Int64(v) => {
972                        let o = call_static_method!(
973                            env,
974                            {Long},
975                            {Long	valueOf(long l)},
976                            v
977                        )?;
978                        env.set_object_array_element(&jarray, index, &o)?;
979                    }
980                    ScalarRefImpl::Float32(v) => {
981                        let o = call_static_method!(
982                            env,
983                            {Float},
984                            {Float	valueOf(float f)},
985                            v.into_inner()
986                        )?;
987                        env.set_object_array_element(&jarray, index, &o)?;
988                    }
989                    ScalarRefImpl::Float64(v) => {
990                        let o = call_static_method!(
991                            env,
992                            {Double},
993                            {Double	valueOf(double d)},
994                            v.into_inner()
995                        )?;
996                        env.set_object_array_element(&jarray, index, &o)?;
997                    }
998                    ScalarRefImpl::Utf8(v) => {
999                        let obj = env.new_string(v)?;
1000                        env.set_object_array_element(&jarray, index, obj)?
1001                    }
1002                    _ => env.set_object_array_element(&jarray, index, JObject::null())?,
1003                },
1004            }
1005        }
1006        let output = unsafe { JObject::from_raw(jarray.into_raw()) };
1007        Ok(output)
1008    })
1009}
1010
1011pub type JniSenderType<T> = Sender<anyhow::Result<T>>;
1012pub type JniReceiverType<T> = Receiver<T>;
1013
1014/// Send messages to the channel received by `CdcSplitReader`.
1015/// If msg is null, just check whether the channel is closed.
1016/// Return true if sending is successful, otherwise, return false so that caller can stop
1017/// gracefully.
1018#[unsafe(no_mangle)]
1019extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel<'a>(
1020    env: EnvParam<'a>,
1021    channel: Pointer<'a, JniSenderType<GetEventStreamResponse>>,
1022    msg: JByteArray<'a>,
1023) -> jboolean {
1024    execute_and_catch(env, move |env| {
1025        // If msg is null means just check whether channel is closed.
1026        if msg.is_null() {
1027            if channel.as_ref().is_closed() {
1028                return Ok(JNI_FALSE);
1029            } else {
1030                return Ok(JNI_TRUE);
1031            }
1032        }
1033
1034        let get_event_stream_response: GetEventStreamResponse =
1035            Message::decode(to_guarded_slice(&msg, env)?.deref())?;
1036
1037        match channel
1038            .as_ref()
1039            .blocking_send(Ok(get_event_stream_response))
1040        {
1041            Ok(_) => Ok(JNI_TRUE),
1042            Err(e) => {
1043                tracing::info!(error = %e.as_report(), "send error");
1044                Ok(JNI_FALSE)
1045            }
1046        }
1047    })
1048}
1049
1050#[unsafe(no_mangle)]
1051extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceErrorToChannel<'a>(
1052    env: EnvParam<'a>,
1053    channel: Pointer<'a, JniSenderType<GetEventStreamResponse>>,
1054    msg: JString<'a>,
1055) -> jboolean {
1056    execute_and_catch(env, move |env| {
1057        let ret = env.get_string(&msg);
1058        match ret {
1059            Ok(str) => {
1060                let err_msg: String = str.into();
1061                match channel.as_ref().blocking_send(Err(anyhow!(err_msg))) {
1062                    Ok(_) => Ok(JNI_TRUE),
1063                    Err(e) => {
1064                        tracing::info!(error = ?e.as_report(), "send error");
1065                        Ok(JNI_FALSE)
1066                    }
1067                }
1068            }
1069            Err(err) => {
1070                if msg.is_null() {
1071                    tracing::warn!("source error message is null");
1072                    Ok(JNI_TRUE)
1073                } else {
1074                    tracing::error!(error = ?err.as_report(), "source error message should be a java string");
1075                    Ok(JNI_FALSE)
1076                }
1077            }
1078        }
1079    })
1080}
1081
/// Releases the CDC source sender referenced by `channel` by handing it to
/// `OwnedPointer::release`.
///
/// `channel` is consumed here, so the Java side must not use this pointer again afterwards
/// (presumably `release` frees the boxed sender — confirm in `OwnedPointer::release`).
#[unsafe(no_mangle)]
extern "system" fn Java_com_risingwave_java_binding_Binding_cdcSourceSenderClose(
    _env: EnvParam<'_>,
    channel: OwnedPointer<JniSenderType<GetEventStreamResponse>>,
) {
    channel.release();
}
1089
1090pub enum JniSinkWriterStreamRequest {
1091    PbRequest(SinkWriterStreamRequest),
1092    Chunk {
1093        epoch: u64,
1094        batch_id: u64,
1095        chunk: StreamChunk,
1096    },
1097}
1098
1099impl From<SinkWriterStreamRequest> for JniSinkWriterStreamRequest {
1100    fn from(value: SinkWriterStreamRequest) -> Self {
1101        Self::PbRequest(value)
1102    }
1103}
1104
1105#[unsafe(no_mangle)]
1106pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel<
1107    'a,
1108>(
1109    env: EnvParam<'a>,
1110    mut channel: Pointer<'a, JniReceiverType<JniSinkWriterStreamRequest>>,
1111) -> JObject<'a> {
1112    execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() {
1113        Some(msg) => {
1114            let obj = match msg {
1115                JniSinkWriterStreamRequest::PbRequest(request) => {
1116                    let bytes = env.byte_array_from_slice(&Message::encode_to_vec(&request))?;
1117                    let jobj = JObject::from(bytes);
1118                    call_static_method!(
1119                        env,
1120                        {com.risingwave.java.binding.JniSinkWriterStreamRequest},
1121                        {com.risingwave.java.binding.JniSinkWriterStreamRequest fromSerializedPayload(byte[] payload)},
1122                        &jobj
1123                    )?
1124                }
1125                JniSinkWriterStreamRequest::Chunk {
1126                    epoch,
1127                    batch_id,
1128                    chunk,
1129                } => {
1130                    let pointer = Box::into_raw(Box::new(chunk));
1131                    call_static_method!(
1132                        env,
1133                        {com.risingwave.java.binding.JniSinkWriterStreamRequest},
1134                        {com.risingwave.java.binding.JniSinkWriterStreamRequest fromStreamChunkOwnedPointer(long pointer, long epoch, long batchId)},
1135                        pointer as u64, epoch, batch_id
1136                    )
1137                    .inspect_err(|_| unsafe {
1138                        // release the stream chunk on err
1139                        drop(Box::from_raw(pointer));
1140                    })?
1141                }
1142            };
1143            Ok(obj)
1144        }
1145        None => Ok(JObject::null()),
1146    })
1147}
1148
1149#[unsafe(no_mangle)]
1150pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterResponseToChannel<
1151    'a,
1152>(
1153    env: EnvParam<'a>,
1154    channel: Pointer<'a, JniSenderType<SinkWriterStreamResponse>>,
1155    msg: JByteArray<'a>,
1156) -> jboolean {
1157    execute_and_catch(env, move |env| {
1158        let sink_writer_stream_response: SinkWriterStreamResponse =
1159            Message::decode(to_guarded_slice(&msg, env)?.deref())?;
1160
1161        match channel
1162            .as_ref()
1163            .blocking_send(Ok(sink_writer_stream_response))
1164        {
1165            Ok(_) => Ok(JNI_TRUE),
1166            Err(e) => {
1167                tracing::info!(error = ?e.as_report(), "send error");
1168                Ok(JNI_FALSE)
1169            }
1170        }
1171    })
1172}
1173
1174#[unsafe(no_mangle)]
1175pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterErrorToChannel<'a>(
1176    env: EnvParam<'a>,
1177    channel: Pointer<'a, Sender<anyhow::Result<SinkWriterStreamResponse>>>,
1178    msg: JString<'a>,
1179) -> jboolean {
1180    execute_and_catch(env, move |env| {
1181        let ret = env.get_string(&msg);
1182        match ret {
1183            Ok(str) => {
1184                let err_msg: String = str.into();
1185                match channel.as_ref().blocking_send(Err(anyhow!(err_msg))) {
1186                    Ok(_) => Ok(JNI_TRUE),
1187                    Err(e) => {
1188                        tracing::info!(error = ?e.as_report(), "send error");
1189                        Ok(JNI_FALSE)
1190                    }
1191                }
1192            }
1193            Err(err) => {
1194                if msg.is_null() {
1195                    tracing::warn!("sink error message is null");
1196                    Ok(JNI_TRUE)
1197                } else {
1198                    tracing::error!(error = ?err.as_report(), "sink error message should be a java string");
1199                    Ok(JNI_FALSE)
1200                }
1201            }
1202        }
1203    })
1204}
1205
1206#[unsafe(no_mangle)]
1207pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkCoordinatorRequestFromChannel<
1208    'a,
1209>(
1210    env: EnvParam<'a>,
1211    mut channel: Pointer<'a, JniReceiverType<SinkCoordinatorStreamRequest>>,
1212) -> JByteArray<'a> {
1213    execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() {
1214        Some(msg) => {
1215            let bytes = env
1216                .byte_array_from_slice(&Message::encode_to_vec(&msg))
1217                .unwrap();
1218            Ok(bytes)
1219        }
1220        None => Ok(JObject::null().into()),
1221    })
1222}
1223
1224#[unsafe(no_mangle)]
1225pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkCoordinatorResponseToChannel<
1226    'a,
1227>(
1228    env: EnvParam<'a>,
1229    channel: Pointer<'a, JniSenderType<SinkCoordinatorStreamResponse>>,
1230    msg: JByteArray<'a>,
1231) -> jboolean {
1232    execute_and_catch(env, move |env| {
1233        let sink_coordinator_stream_response: SinkCoordinatorStreamResponse =
1234            Message::decode(to_guarded_slice(&msg, env)?.deref())?;
1235
1236        match channel
1237            .as_ref()
1238            .blocking_send(Ok(sink_coordinator_stream_response))
1239        {
1240            Ok(_) => Ok(JNI_TRUE),
1241            Err(e) => {
1242                tracing::info!(error = ?e.as_report(), "send error");
1243                Ok(JNI_FALSE)
1244            }
1245        }
1246    })
1247}
1248
#[cfg(test)]
mod tests {
    use risingwave_common::types::Timestamptz;

    /// Make sure that the [`ScalarRefImpl::Int64`] received by
    /// [`Java_com_risingwave_java_binding_Binding_iteratorGetTimestampValue`]
    /// is of type [`DataType::Timestamptz`] stored in microseconds.
    #[test]
    fn test_timestamptz_to_i64() {
        // 2023-06-01 09:45:00+08:00 is 2023-06-01 01:45:00 UTC, i.e. 1_685_583_900 s after
        // the Unix epoch.
        let parsed: Timestamptz = "2023-06-01 09:45:00+08:00".parse().unwrap();
        let expected = Timestamptz::from_micros(1_685_583_900_000_000);
        assert_eq!(parsed, expected);
    }
}