// risingwave_jni_core/lib.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15#![feature(error_generic_member_access)]
16#![feature(once_cell_try)]
17#![feature(type_alias_impl_trait)]
18#![feature(try_blocks)]
19
20pub mod jvm_runtime;
21mod macros;
22mod tracing_slf4j;
23
24use std::backtrace::Backtrace;
25use std::marker::PhantomData;
26use std::ops::{Deref, DerefMut};
27use std::slice::from_raw_parts;
28use std::sync::{LazyLock, OnceLock};
29
30use anyhow::anyhow;
31use bytes::Bytes;
32use cfg_or_panic::cfg_or_panic;
33use chrono::{Datelike, NaiveDateTime, Timelike};
34use futures::TryStreamExt;
35use futures::stream::BoxStream;
36use jni::JNIEnv;
37use jni::objects::{
38    AutoElements, GlobalRef, JByteArray, JClass, JMethodID, JObject, JStaticMethodID, JString,
39    JValueOwned, ReleaseMode,
40};
41use jni::signature::ReturnType;
42use jni::sys::{
43    JNI_FALSE, JNI_TRUE, jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue,
44};
45pub use paste::paste;
46use prost::{DecodeError, Message};
47use risingwave_common::array::{ArrayError, StreamChunk};
48use risingwave_common::hash::VirtualNode;
49use risingwave_common::row::{OwnedRow, Row};
50use risingwave_common::test_prelude::StreamChunkTestExt;
51use risingwave_common::types::{Decimal, ScalarRefImpl};
52use risingwave_common::util::panic::rw_catch_unwind;
53use risingwave_pb::connector_service::{
54    GetEventStreamResponse, SinkCoordinatorStreamRequest, SinkCoordinatorStreamResponse,
55    SinkWriterStreamRequest, SinkWriterStreamResponse,
56};
57use risingwave_pb::data::Op;
58use thiserror::Error;
59use thiserror_ext::AsReport;
60use tokio::runtime::Runtime;
61use tokio::sync::mpsc::{Receiver, Sender};
62use tracing_slf4j::*;
63
64pub static JAVA_BINDING_ASYNC_RUNTIME: LazyLock<Runtime> =
65    LazyLock::new(|| tokio::runtime::Runtime::new().unwrap());
66
67#[derive(Error, Debug)]
68pub enum BindingError {
69    #[error("JniError {error}")]
70    Jni {
71        #[from]
72        error: jni::errors::Error,
73        backtrace: Backtrace,
74    },
75
76    #[error("StorageError {error}")]
77    Storage {
78        #[from]
79        error: anyhow::Error,
80        backtrace: Backtrace,
81    },
82
83    #[error("DecodeError {error}")]
84    Decode {
85        #[from]
86        error: DecodeError,
87        backtrace: Backtrace,
88    },
89
90    #[error("StreamChunkArrayError {error}")]
91    StreamChunkArray {
92        #[from]
93        error: ArrayError,
94        backtrace: Backtrace,
95    },
96}
97
98type Result<T> = std::result::Result<T, BindingError>;
99
100pub fn to_guarded_slice<'array, 'env>(
101    array: &'array JByteArray<'env>,
102    env: &'array mut JNIEnv<'env>,
103) -> Result<SliceGuard<'env, 'array>> {
104    unsafe {
105        let array = env.get_array_elements(array, ReleaseMode::NoCopyBack)?;
106        let slice = from_raw_parts(array.as_ptr() as *mut u8, array.len());
107
108        Ok(SliceGuard {
109            _array: array,
110            slice,
111        })
112    }
113}
114
115/// Wrapper around `&[u8]` derived from `jbyteArray` to prevent it from being auto-released.
116pub struct SliceGuard<'env, 'array> {
117    _array: AutoElements<'env, 'env, 'array, jbyte>,
118    slice: &'array [u8],
119}
120
121impl Deref for SliceGuard<'_, '_> {
122    type Target = [u8];
123
124    fn deref(&self) -> &Self::Target {
125        self.slice
126    }
127}
128
129#[repr(transparent)]
130pub struct Pointer<'a, T> {
131    pointer: jlong,
132    _phantom: PhantomData<&'a T>,
133}
134
135impl<T> Default for Pointer<'_, T> {
136    fn default() -> Self {
137        Self {
138            pointer: 0,
139            _phantom: Default::default(),
140        }
141    }
142}
143
144impl<T> From<T> for Pointer<'static, T> {
145    fn from(value: T) -> Self {
146        Pointer {
147            pointer: Box::into_raw(Box::new(value)) as jlong,
148            _phantom: PhantomData,
149        }
150    }
151}
152
153impl<'a, T> Pointer<'a, T> {
154    fn as_ref(&self) -> &'a T {
155        assert!(self.pointer != 0);
156        unsafe { &*(self.pointer as *const T) }
157    }
158
159    fn as_mut(&mut self) -> &'a mut T {
160        assert!(self.pointer != 0);
161        unsafe { &mut *(self.pointer as *mut T) }
162    }
163}
164
165/// A pointer that owns the object it points to.
166///
167/// Note that dropping an `OwnedPointer` does not release the object.
168/// Instead, you should call [`OwnedPointer::release`] manually.
169pub type OwnedPointer<T> = Pointer<'static, T>;
170
171impl<T> OwnedPointer<T> {
172    /// Consume `self` and return the pointer value. Used for passing to JNI.
173    pub fn into_pointer(self) -> jlong {
174        self.pointer
175    }
176
177    /// Release the object behind the pointer.
178    fn release(self) {
179        tracing::debug!(
180            type_name = std::any::type_name::<T>(),
181            address = %format_args!("{:x}", self.pointer),
182            "release jni OwnedPointer"
183        );
184        assert!(self.pointer != 0);
185        unsafe { drop(Box::from_raw(self.pointer as *mut T)) }
186    }
187}
188
189/// In most Jni interfaces, the first parameter is `JNIEnv`, and the second parameter is `JClass`.
190/// This struct simply encapsulates the two common parameters into a single struct for simplicity.
191#[repr(C)]
192pub struct EnvParam<'a> {
193    env: JNIEnv<'a>,
194    class: JClass<'a>,
195}
196
197impl<'a> Deref for EnvParam<'a> {
198    type Target = JNIEnv<'a>;
199
200    fn deref(&self) -> &Self::Target {
201        &self.env
202    }
203}
204
205impl DerefMut for EnvParam<'_> {
206    fn deref_mut(&mut self) -> &mut Self::Target {
207        &mut self.env
208    }
209}
210
211impl<'a> EnvParam<'a> {
212    pub fn get_class(&self) -> &JClass<'a> {
213        &self.class
214    }
215}
216
217pub fn execute_and_catch<'env, F, Ret>(mut env: EnvParam<'env>, inner: F) -> Ret
218where
219    F: FnOnce(&mut EnvParam<'env>) -> Result<Ret>,
220    Ret: Default + 'env,
221{
222    match rw_catch_unwind(std::panic::AssertUnwindSafe(|| inner(&mut env))) {
223        Ok(Ok(ret)) => ret,
224        Ok(Err(e)) => {
225            match e {
226                BindingError::Jni {
227                    error: jni::errors::Error::JavaException,
228                    backtrace,
229                } => {
230                    tracing::error!("get JavaException thrown from: {:?}", backtrace);
231                    // the exception is already thrown. No need to throw again
232                }
233                _ => {
234                    env.throw(format!("get error while processing: {:?}", e.as_report()))
235                        .expect("should be able to throw");
236                }
237            }
238            Ret::default()
239        }
240        Err(e) => {
241            env.throw(format!("panic while processing: {:?}", e))
242                .expect("should be able to throw");
243            Ret::default()
244        }
245    }
246}
247
248#[derive(Default)]
249struct JavaClassMethodCache {
250    big_decimal_ctor: OnceLock<(GlobalRef, JMethodID)>,
251
252    timestamp_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
253    timestamptz_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
254    date_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
255    time_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
256    instant_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
257    utc: OnceLock<GlobalRef>,
258}
259
260mod opaque_type {
261    use super::*;
262    // TODO: may only return a RowRef
263    pub type StreamChunkRowIterator<'a> = impl Iterator<Item = (Op, OwnedRow)> + 'a;
264
265    impl<'a> JavaBindingIteratorInner<'a> {
266        pub(super) fn from_chunk(chunk: &'a StreamChunk) -> JavaBindingIteratorInner<'a> {
267            JavaBindingIteratorInner::StreamChunk(
268                chunk
269                    .rows()
270                    .map(|(op, row)| (op.to_protobuf(), row.to_owned_row())),
271            )
272        }
273    }
274}
275pub use opaque_type::StreamChunkRowIterator;
276pub type HummockJavaBindingIterator = BoxStream<'static, anyhow::Result<(Bytes, OwnedRow)>>;
277pub enum JavaBindingIteratorInner<'a> {
278    Hummock(HummockJavaBindingIterator),
279    StreamChunk(StreamChunkRowIterator<'a>),
280}
281
282enum RowExtra {
283    Op(Op),
284    Key(Bytes),
285}
286
287impl RowExtra {
288    fn as_op(&self) -> Op {
289        match self {
290            RowExtra::Op(op) => *op,
291            RowExtra::Key(_) => unreachable!("should be op"),
292        }
293    }
294
295    fn as_key(&self) -> &Bytes {
296        match self {
297            RowExtra::Key(key) => key,
298            RowExtra::Op(_) => unreachable!("should be key"),
299        }
300    }
301}
302
303struct RowCursor {
304    row: OwnedRow,
305    extra: RowExtra,
306}
307
308pub struct JavaBindingIterator<'a> {
309    inner: JavaBindingIteratorInner<'a>,
310    cursor: Option<RowCursor>,
311    class_cache: JavaClassMethodCache,
312}
313
314impl JavaBindingIterator<'static> {
315    pub fn new_hummock_iter(iter: HummockJavaBindingIterator) -> Self {
316        Self {
317            inner: JavaBindingIteratorInner::Hummock(iter),
318            cursor: None,
319            class_cache: Default::default(),
320        }
321    }
322}
323
324impl Deref for JavaBindingIterator<'_> {
325    type Target = OwnedRow;
326
327    fn deref(&self) -> &Self::Target {
328        &self
329            .cursor
330            .as_ref()
331            .expect("should exist when call row methods")
332            .row
333    }
334}
335
336#[unsafe(no_mangle)]
337extern "system" fn Java_com_risingwave_java_binding_Binding_defaultVnodeCount(
338    _env: EnvParam<'_>,
339) -> jint {
340    VirtualNode::COUNT_FOR_COMPAT as jint
341}
342
343#[cfg_or_panic(not(madsim))]
344#[unsafe(no_mangle)]
345extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNewStreamChunk<'a>(
346    env: EnvParam<'a>,
347    chunk: Pointer<'a, StreamChunk>,
348) -> Pointer<'static, JavaBindingIterator<'a>> {
349    execute_and_catch(env, move |_env| {
350        let iter = JavaBindingIterator {
351            inner: JavaBindingIteratorInner::from_chunk(chunk.as_ref()),
352            cursor: None,
353            class_cache: Default::default(),
354        };
355        Ok(iter.into())
356    })
357}
358
359#[cfg_or_panic(not(madsim))]
360#[unsafe(no_mangle)]
361extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNext<'a>(
362    env: EnvParam<'a>,
363    mut pointer: Pointer<'a, JavaBindingIterator<'a>>,
364) -> jboolean {
365    execute_and_catch(env, move |_env| {
366        let iter = pointer.as_mut();
367        match &mut iter.inner {
368            JavaBindingIteratorInner::Hummock(hummock_iter) => {
369                match JAVA_BINDING_ASYNC_RUNTIME.block_on(hummock_iter.try_next())? {
370                    None => {
371                        iter.cursor = None;
372                        Ok(JNI_FALSE)
373                    }
374                    Some((key, row)) => {
375                        iter.cursor = Some(RowCursor {
376                            row,
377                            extra: RowExtra::Key(key),
378                        });
379                        Ok(JNI_TRUE)
380                    }
381                }
382            }
383            JavaBindingIteratorInner::StreamChunk(stream_chunk_iter) => {
384                match stream_chunk_iter.next() {
385                    None => {
386                        iter.cursor = None;
387                        Ok(JNI_FALSE)
388                    }
389                    Some((op, row)) => {
390                        iter.cursor = Some(RowCursor {
391                            row,
392                            extra: RowExtra::Op(op),
393                        });
394                        Ok(JNI_TRUE)
395                    }
396                }
397            }
398        }
399    })
400}
401
402#[unsafe(no_mangle)]
403extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorClose<'a>(
404    _env: EnvParam<'a>,
405    pointer: OwnedPointer<JavaBindingIterator<'a>>,
406) {
407    pointer.release()
408}
409
410#[unsafe(no_mangle)]
411extern "system" fn Java_com_risingwave_java_binding_Binding_newStreamChunkFromPayload<'a>(
412    env: EnvParam<'a>,
413    stream_chunk_payload: JByteArray<'a>,
414) -> Pointer<'static, StreamChunk> {
415    execute_and_catch(env, move |env| {
416        let prost_stream_chumk =
417            Message::decode(to_guarded_slice(&stream_chunk_payload, env)?.deref())?;
418        Ok(StreamChunk::from_protobuf(&prost_stream_chumk)?.into())
419    })
420}
421
422#[unsafe(no_mangle)]
423extern "system" fn Java_com_risingwave_java_binding_Binding_newStreamChunkFromPretty<'a>(
424    env: EnvParam<'a>,
425    str: JString<'a>,
426) -> Pointer<'static, StreamChunk> {
427    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
428        Ok(StreamChunk::from_pretty(env.get_string(&str)?.to_str().unwrap()).into())
429    })
430}
431
432#[unsafe(no_mangle)]
433extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkClose(
434    _env: EnvParam<'_>,
435    chunk: OwnedPointer<StreamChunk>,
436) {
437    chunk.release()
438}
439
440#[unsafe(no_mangle)]
441extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetKey<'a>(
442    env: EnvParam<'a>,
443    pointer: Pointer<'a, JavaBindingIterator<'a>>,
444) -> JByteArray<'a> {
445    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
446        Ok(env.byte_array_from_slice(
447            pointer
448                .as_ref()
449                .cursor
450                .as_ref()
451                .expect("should exists when call get key")
452                .extra
453                .as_key()
454                .as_ref(),
455        )?)
456    })
457}
458
459#[unsafe(no_mangle)]
460extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetOp<'a>(
461    env: EnvParam<'a>,
462    pointer: Pointer<'a, JavaBindingIterator<'a>>,
463) -> jint {
464    execute_and_catch(env, move |_env| {
465        Ok(pointer
466            .as_ref()
467            .cursor
468            .as_ref()
469            .expect("should exist when call get op")
470            .extra
471            .as_op() as jint)
472    })
473}
474
475#[unsafe(no_mangle)]
476extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorIsNull<'a>(
477    env: EnvParam<'a>,
478    pointer: Pointer<'a, JavaBindingIterator<'a>>,
479    idx: jint,
480) -> jboolean {
481    execute_and_catch(env, move |_env| {
482        Ok(pointer.as_ref().datum_at(idx as usize).is_none() as jboolean)
483    })
484}
485
486#[unsafe(no_mangle)]
487extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt16Value<'a>(
488    env: EnvParam<'a>,
489    pointer: Pointer<'a, JavaBindingIterator<'a>>,
490    idx: jint,
491) -> jshort {
492    execute_and_catch(env, move |_env| {
493        Ok(pointer
494            .as_ref()
495            .datum_at(idx as usize)
496            .unwrap()
497            .into_int16())
498    })
499}
500
501#[unsafe(no_mangle)]
502extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt32Value<'a>(
503    env: EnvParam<'a>,
504    pointer: Pointer<'a, JavaBindingIterator<'a>>,
505    idx: jint,
506) -> jint {
507    execute_and_catch(env, move |_env| {
508        Ok(pointer
509            .as_ref()
510            .datum_at(idx as usize)
511            .unwrap()
512            .into_int32())
513    })
514}
515
516#[unsafe(no_mangle)]
517extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt64Value<'a>(
518    env: EnvParam<'a>,
519    pointer: Pointer<'a, JavaBindingIterator<'a>>,
520    idx: jint,
521) -> jlong {
522    execute_and_catch(env, move |_env| {
523        Ok(pointer
524            .as_ref()
525            .datum_at(idx as usize)
526            .unwrap()
527            .into_int64())
528    })
529}
530
531#[unsafe(no_mangle)]
532extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetFloatValue<'a>(
533    env: EnvParam<'a>,
534    pointer: Pointer<'a, JavaBindingIterator<'a>>,
535    idx: jint,
536) -> jfloat {
537    execute_and_catch(env, move |_env| {
538        Ok(pointer
539            .as_ref()
540            .datum_at(idx as usize)
541            .unwrap()
542            .into_float32()
543            .into())
544    })
545}
546
547#[unsafe(no_mangle)]
548extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDoubleValue<'a>(
549    env: EnvParam<'a>,
550    pointer: Pointer<'a, JavaBindingIterator<'a>>,
551    idx: jint,
552) -> jdouble {
553    execute_and_catch(env, move |_env| {
554        Ok(pointer
555            .as_ref()
556            .datum_at(idx as usize)
557            .unwrap()
558            .into_float64()
559            .into())
560    })
561}
562
563#[unsafe(no_mangle)]
564extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetBooleanValue<'a>(
565    env: EnvParam<'a>,
566    pointer: Pointer<'a, JavaBindingIterator<'a>>,
567    idx: jint,
568) -> jboolean {
569    execute_and_catch(env, move |_env| {
570        Ok(pointer.as_ref().datum_at(idx as usize).unwrap().into_bool() as jboolean)
571    })
572}
573
574#[unsafe(no_mangle)]
575extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetStringValue<'a>(
576    env: EnvParam<'a>,
577    pointer: Pointer<'a, JavaBindingIterator<'a>>,
578    idx: jint,
579) -> JString<'a> {
580    execute_and_catch(env, move |env: &mut EnvParam<'a>| {
581        Ok(env.new_string(pointer.as_ref().datum_at(idx as usize).unwrap().into_utf8())?)
582    })
583}
584
585#[unsafe(no_mangle)]
586extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetIntervalValue<'a>(
587    env: EnvParam<'a>,
588    pointer: Pointer<'a, JavaBindingIterator<'a>>,
589    idx: jint,
590) -> JString<'a> {
591    execute_and_catch(env, move |env: &mut EnvParam<'a>| {
592        let interval = pointer
593            .as_ref()
594            .datum_at(idx as usize)
595            .unwrap()
596            .into_interval()
597            .as_iso_8601();
598        Ok(env.new_string(interval)?)
599    })
600}
601
602#[unsafe(no_mangle)]
603extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetJsonbValue<'a>(
604    env: EnvParam<'a>,
605    pointer: Pointer<'a, JavaBindingIterator<'a>>,
606    idx: jint,
607) -> JString<'a> {
608    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
609        let jsonb = pointer
610            .as_ref()
611            .datum_at(idx as usize)
612            .unwrap()
613            .into_jsonb()
614            .to_string();
615        Ok(env.new_string(jsonb)?)
616    })
617}
618
619#[unsafe(no_mangle)]
620extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimestampValue<'a>(
621    env: EnvParam<'a>,
622    pointer: Pointer<'a, JavaBindingIterator<'a>>,
623    idx: jint,
624) -> JObject<'a> {
625    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
626        let value = pointer
627            .as_ref()
628            .datum_at(idx as usize)
629            .unwrap()
630            .into_timestamp();
631
632        let sig = gen_jni_sig!(java.time.LocalDateTime of(int year, int month, int dayOfMonth, int hour, int minute, int second, int nanoOfSecond));
633
634        let (timestamp_class_ref, constructor) = pointer
635            .as_ref()
636            .class_cache
637            .timestamp_ctor
638            .get_or_try_init(|| {
639                let cls = env.find_class(gen_class_name!(java.time.LocalDateTime))?;
640                let init_method = env.get_static_method_id(&cls, "of", sig)?;
641                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
642            })?;
643        unsafe {
644            let JValueOwned::Object(timestamp_obj) = env.call_static_method_unchecked(
645                <&JClass<'_>>::from(timestamp_class_ref.as_obj()),
646                *constructor,
647                ReturnType::Object,
648                &[
649                    jvalue { i: value.0.year() },
650                    jvalue {
651                        i: value.0.month() as i32,
652                    },
653                    jvalue {
654                        i: value.0.day() as i32,
655                    },
656                    jvalue {
657                        i: value.0.hour() as i32,
658                    },
659                    jvalue {
660                        i: value.0.minute() as i32,
661                    },
662                    jvalue {
663                        i: value.0.second() as i32,
664                    },
665                    jvalue {
666                        i: value.0.nanosecond() as i32,
667                    },
668                ],
669            )?
670            else {
671                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
672                    name: "of".to_owned(),
673                    sig: sig.into(),
674                }));
675            };
676            Ok(timestamp_obj)
677        }
678    })
679}
680
681#[unsafe(no_mangle)]
682extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimestamptzValue<'a>(
683    env: EnvParam<'a>,
684    pointer: Pointer<'a, JavaBindingIterator<'a>>,
685    idx: jint,
686) -> JObject<'a> {
687    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
688        let value = pointer
689            .as_ref()
690            .datum_at(idx as usize)
691            .unwrap()
692            .into_timestamptz();
693
694        let instant_sig =
695            gen_jni_sig!(java.time.Instant ofEpochSecond(long epochSecond, long nanoAdjustment));
696
697        let (instant_class_ref, instant_constructor) = pointer
698            .as_ref()
699            .class_cache
700            .instant_ctor
701            .get_or_try_init(|| {
702                let cls = env.find_class(gen_class_name!(java.time.Instant))?;
703                let init_method = env.get_static_method_id(&cls, "ofEpochSecond", instant_sig)?;
704                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
705            })?;
706        let instant_obj = unsafe {
707            let JValueOwned::Object(instant_obj) = env.call_static_method_unchecked(
708                <&JClass<'_>>::from(instant_class_ref.as_obj()),
709                *instant_constructor,
710                ReturnType::Object,
711                &[
712                    jvalue {
713                        j: value.timestamp(),
714                    },
715                    jvalue {
716                        j: value.timestamp_subsec_nanos() as i64,
717                    },
718                ],
719            )?
720            else {
721                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
722                    name: "ofEpochSecond".to_owned(),
723                    sig: instant_sig.into(),
724                }));
725            };
726            instant_obj
727        };
728
729        let utc_ref = pointer.as_ref().class_cache.utc.get_or_try_init(|| {
730            let cls = env.find_class(gen_class_name!(java.time.ZoneOffset))?;
731            let utc = env
732                .get_static_field(&cls, "UTC", gen_jni_type_sig!(java.time.ZoneOffset))?
733                .l()?;
734            env.new_global_ref(utc)
735        })?;
736
737        let sig = gen_jni_sig!(java.time.OffsetDateTime ofInstant(java.time.Instant instant, java.time.ZoneId zone));
738
739        let (timestamptz_class_ref, constructor) = pointer
740            .as_ref()
741            .class_cache
742            .timestamptz_ctor
743            .get_or_try_init(|| {
744                let cls = env.find_class(gen_class_name!(java.time.OffsetDateTime))?;
745                let init_method = env.get_static_method_id(&cls, "ofInstant", sig)?;
746                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
747            })?;
748        unsafe {
749            let JValueOwned::Object(timestamptz_obj) = env.call_static_method_unchecked(
750                <&JClass<'_>>::from(timestamptz_class_ref.as_obj()),
751                *constructor,
752                ReturnType::Object,
753                &[
754                    jvalue {
755                        l: instant_obj.as_raw(),
756                    },
757                    jvalue {
758                        l: utc_ref.as_obj().as_raw(),
759                    },
760                ],
761            )?
762            else {
763                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
764                    name: "ofInstant".to_owned(),
765                    sig: sig.into(),
766                }));
767            };
768            Ok(timestamptz_obj)
769        }
770    })
771}
772
773#[unsafe(no_mangle)]
774extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDecimalValue<'a>(
775    env: EnvParam<'a>,
776    pointer: Pointer<'a, JavaBindingIterator<'a>>,
777    idx: jint,
778) -> JObject<'a> {
779    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
780        let decimal_value = pointer
781            .as_ref()
782            .datum_at(idx as usize)
783            .unwrap()
784            .into_decimal();
785
786        match decimal_value {
787            Decimal::NaN | Decimal::NegativeInf | Decimal::PositiveInf => {
788                return Ok(JObject::null());
789            }
790            Decimal::Normalized(_) => {}
791        };
792
793        let value = decimal_value.to_string();
794        let string_value = env.new_string(value)?;
795        let (decimal_class_ref, constructor) = pointer
796            .as_ref()
797            .class_cache
798            .big_decimal_ctor
799            .get_or_try_init(|| {
800                let cls = env.find_class("java/math/BigDecimal")?;
801                let init_method = env.get_method_id(&cls, "<init>", "(Ljava/lang/String;)V")?;
802                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
803            })?;
804        unsafe {
805            let decimal_class = <&JClass<'_>>::from(decimal_class_ref.as_obj());
806            let date_obj = env.new_object_unchecked(
807                decimal_class,
808                *constructor,
809                &[jvalue {
810                    l: string_value.into_raw(),
811                }],
812            )?;
813            Ok(date_obj)
814        }
815    })
816}
817
818#[unsafe(no_mangle)]
819extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDateValue<'a>(
820    env: EnvParam<'a>,
821    pointer: Pointer<'a, JavaBindingIterator<'a>>,
822    idx: jint,
823) -> JObject<'a> {
824    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
825        let value = pointer.as_ref().datum_at(idx as usize).unwrap().into_date();
826        let epoch_days = (value.0 - NaiveDateTime::UNIX_EPOCH.date()).num_days();
827
828        let sig = gen_jni_sig!(java.time.LocalDate ofEpochDay(long));
829
830        let (date_class_ref, constructor) =
831            pointer.as_ref().class_cache.date_ctor.get_or_try_init(|| {
832                let cls = env.find_class(gen_class_name!(java.time.LocalDate))?;
833                let init_method = env.get_static_method_id(&cls, "ofEpochDay", sig)?;
834                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
835            })?;
836        unsafe {
837            let JValueOwned::Object(date_obj) = env.call_static_method_unchecked(
838                <&JClass<'_>>::from(date_class_ref.as_obj()),
839                *constructor,
840                ReturnType::Object,
841                &[jvalue { j: epoch_days }],
842            )?
843            else {
844                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
845                    name: "ofEpochDay".to_owned(),
846                    sig: sig.into(),
847                }));
848            };
849            Ok(date_obj)
850        }
851    })
852}
853
854#[unsafe(no_mangle)]
855extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimeValue<'a>(
856    env: EnvParam<'a>,
857    pointer: Pointer<'a, JavaBindingIterator<'a>>,
858    idx: jint,
859) -> JObject<'a> {
860    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
861        let value = pointer.as_ref().datum_at(idx as usize).unwrap().into_time();
862
863        let sig = gen_jni_sig!(java.time.LocalTime of(int hour, int minute, int second, int nanoOfSecond));
864
865        let (time_class_ref, constructor) =
866            pointer.as_ref().class_cache.time_ctor.get_or_try_init(|| {
867                let cls = env.find_class(gen_class_name!(java.time.LocalTime))?;
868                let init_method = env.get_static_method_id(&cls, "of", sig)?;
869                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
870            })?;
871        unsafe {
872            let JValueOwned::Object(time_obj) = env.call_static_method_unchecked(
873                <&JClass<'_>>::from(time_class_ref.as_obj()),
874                *constructor,
875                ReturnType::Object,
876                &[
877                    jvalue {
878                        i: value.0.hour() as i32,
879                    },
880                    jvalue {
881                        i: value.0.minute() as i32,
882                    },
883                    jvalue {
884                        i: value.0.second() as i32,
885                    },
886                    jvalue {
887                        i: value.0.nanosecond() as i32,
888                    },
889                ],
890            )?
891            else {
892                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
893                    name: "of".to_owned(),
894                    sig: sig.into(),
895                }));
896            };
897            Ok(time_obj)
898        }
899    })
900}
901
902#[unsafe(no_mangle)]
903extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetByteaValue<'a>(
904    env: EnvParam<'a>,
905    pointer: Pointer<'a, JavaBindingIterator<'a>>,
906    idx: jint,
907) -> JByteArray<'a> {
908    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
909        let bytes = pointer
910            .as_ref()
911            .datum_at(idx as usize)
912            .unwrap()
913            .into_bytea();
914        Ok(env.byte_array_from_slice(bytes)?)
915    })
916}
917
918#[unsafe(no_mangle)]
919extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetArrayValue<'a>(
920    env: EnvParam<'a>,
921    pointer: Pointer<'a, JavaBindingIterator<'a>>,
922    idx: jint,
923    class: JClass<'a>,
924) -> JObject<'a> {
925    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
926        let elems = pointer
927            .as_ref()
928            .datum_at(idx as usize)
929            .unwrap()
930            .into_list()
931            .iter();
932
933        // convert the Rust elements to a Java object array (Object[])
934        let jarray = env.new_object_array(elems.len() as jsize, &class, JObject::null())?;
935
936        for (i, ele) in elems.enumerate() {
937            let index = i as jsize;
938            match ele {
939                None => env.set_object_array_element(&jarray, i as jsize, JObject::null())?,
940                Some(val) => match val {
941                    ScalarRefImpl::Int16(v) => {
942                        let o = call_static_method!(
943                            env,
944                            {Short},
945                            {Short	valueOf(short s)},
946                            v
947                        )?;
948                        env.set_object_array_element(&jarray, index, &o)?;
949                    }
950                    ScalarRefImpl::Int32(v) => {
951                        let o = call_static_method!(
952                            env,
953                            {Integer},
954                            {Integer	valueOf(int i)},
955                            v
956                        )?;
957                        env.set_object_array_element(&jarray, index, &o)?;
958                    }
959                    ScalarRefImpl::Int64(v) => {
960                        let o = call_static_method!(
961                            env,
962                            {Long},
963                            {Long	valueOf(long l)},
964                            v
965                        )?;
966                        env.set_object_array_element(&jarray, index, &o)?;
967                    }
968                    ScalarRefImpl::Float32(v) => {
969                        let o = call_static_method!(
970                            env,
971                            {Float},
972                            {Float	valueOf(float f)},
973                            v.into_inner()
974                        )?;
975                        env.set_object_array_element(&jarray, index, &o)?;
976                    }
977                    ScalarRefImpl::Float64(v) => {
978                        let o = call_static_method!(
979                            env,
980                            {Double},
981                            {Double	valueOf(double d)},
982                            v.into_inner()
983                        )?;
984                        env.set_object_array_element(&jarray, index, &o)?;
985                    }
986                    ScalarRefImpl::Utf8(v) => {
987                        let obj = env.new_string(v)?;
988                        env.set_object_array_element(&jarray, index, obj)?
989                    }
990                    _ => env.set_object_array_element(&jarray, index, JObject::null())?,
991                },
992            }
993        }
994        let output = unsafe { JObject::from_raw(jarray.into_raw()) };
995        Ok(output)
996    })
997}
998
999pub type JniSenderType<T> = Sender<anyhow::Result<T>>;
1000pub type JniReceiverType<T> = Receiver<T>;
1001
1002/// Send messages to the channel received by `CdcSplitReader`.
1003/// If msg is null, just check whether the channel is closed.
1004/// Return true if sending is successful, otherwise, return false so that caller can stop
1005/// gracefully.
1006#[unsafe(no_mangle)]
1007extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel<'a>(
1008    env: EnvParam<'a>,
1009    channel: Pointer<'a, JniSenderType<GetEventStreamResponse>>,
1010    msg: JByteArray<'a>,
1011) -> jboolean {
1012    execute_and_catch(env, move |env| {
1013        // If msg is null means just check whether channel is closed.
1014        if msg.is_null() {
1015            if channel.as_ref().is_closed() {
1016                return Ok(JNI_FALSE);
1017            } else {
1018                return Ok(JNI_TRUE);
1019            }
1020        }
1021
1022        let get_event_stream_response: GetEventStreamResponse =
1023            Message::decode(to_guarded_slice(&msg, env)?.deref())?;
1024
1025        match channel
1026            .as_ref()
1027            .blocking_send(Ok(get_event_stream_response))
1028        {
1029            Ok(_) => Ok(JNI_TRUE),
1030            Err(e) => {
1031                tracing::info!(error = %e.as_report(), "send error");
1032                Ok(JNI_FALSE)
1033            }
1034        }
1035    })
1036}
1037
1038#[unsafe(no_mangle)]
1039extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceErrorToChannel<'a>(
1040    env: EnvParam<'a>,
1041    channel: Pointer<'a, JniSenderType<GetEventStreamResponse>>,
1042    msg: JString<'a>,
1043) -> jboolean {
1044    execute_and_catch(env, move |env| {
1045        let ret = env.get_string(&msg);
1046        match ret {
1047            Ok(str) => {
1048                let err_msg: String = str.into();
1049                match channel.as_ref().blocking_send(Err(anyhow!(err_msg))) {
1050                    Ok(_) => Ok(JNI_TRUE),
1051                    Err(e) => {
1052                        tracing::info!(error = ?e.as_report(), "send error");
1053                        Ok(JNI_FALSE)
1054                    }
1055                }
1056            }
1057            Err(err) => {
1058                if msg.is_null() {
1059                    tracing::warn!("source error message is null");
1060                    Ok(JNI_TRUE)
1061                } else {
1062                    tracing::error!(error = ?err.as_report(), "source error message should be a java string");
1063                    Ok(JNI_FALSE)
1064                }
1065            }
1066        }
1067    })
1068}
1069
1070#[unsafe(no_mangle)]
1071extern "system" fn Java_com_risingwave_java_binding_Binding_cdcSourceSenderClose(
1072    _env: EnvParam<'_>,
1073    channel: OwnedPointer<JniSenderType<GetEventStreamResponse>>,
1074) {
1075    channel.release();
1076}
1077
1078pub enum JniSinkWriterStreamRequest {
1079    PbRequest(SinkWriterStreamRequest),
1080    Chunk {
1081        epoch: u64,
1082        batch_id: u64,
1083        chunk: StreamChunk,
1084    },
1085}
1086
1087impl From<SinkWriterStreamRequest> for JniSinkWriterStreamRequest {
1088    fn from(value: SinkWriterStreamRequest) -> Self {
1089        Self::PbRequest(value)
1090    }
1091}
1092
1093#[unsafe(no_mangle)]
1094pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel<
1095    'a,
1096>(
1097    env: EnvParam<'a>,
1098    mut channel: Pointer<'a, JniReceiverType<JniSinkWriterStreamRequest>>,
1099) -> JObject<'a> {
1100    execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() {
1101        Some(msg) => {
1102            let obj = match msg {
1103                JniSinkWriterStreamRequest::PbRequest(request) => {
1104                    let bytes = env.byte_array_from_slice(&Message::encode_to_vec(&request))?;
1105                    call_static_method!(
1106                        env,
1107                        {com.risingwave.java.binding.JniSinkWriterStreamRequest},
1108                        {com.risingwave.java.binding.JniSinkWriterStreamRequest fromSerializedPayload(byte[] payload)},
1109                        &JObject::from(bytes)
1110                    )?
1111                }
1112                JniSinkWriterStreamRequest::Chunk {
1113                    epoch,
1114                    batch_id,
1115                    chunk,
1116                } => {
1117                    let pointer = Box::into_raw(Box::new(chunk));
1118                    call_static_method!(
1119                        env,
1120                        {com.risingwave.java.binding.JniSinkWriterStreamRequest},
1121                        {com.risingwave.java.binding.JniSinkWriterStreamRequest fromStreamChunkOwnedPointer(long pointer, long epoch, long batchId)},
1122                        pointer as u64, epoch, batch_id
1123                    )
1124                    .inspect_err(|_| unsafe {
1125                        // release the stream chunk on err
1126                        drop(Box::from_raw(pointer));
1127                    })?
1128                }
1129            };
1130            Ok(obj)
1131        }
1132        None => Ok(JObject::null()),
1133    })
1134}
1135
1136#[unsafe(no_mangle)]
1137pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterResponseToChannel<
1138    'a,
1139>(
1140    env: EnvParam<'a>,
1141    channel: Pointer<'a, JniSenderType<SinkWriterStreamResponse>>,
1142    msg: JByteArray<'a>,
1143) -> jboolean {
1144    execute_and_catch(env, move |env| {
1145        let sink_writer_stream_response: SinkWriterStreamResponse =
1146            Message::decode(to_guarded_slice(&msg, env)?.deref())?;
1147
1148        match channel
1149            .as_ref()
1150            .blocking_send(Ok(sink_writer_stream_response))
1151        {
1152            Ok(_) => Ok(JNI_TRUE),
1153            Err(e) => {
1154                tracing::info!(error = ?e.as_report(), "send error");
1155                Ok(JNI_FALSE)
1156            }
1157        }
1158    })
1159}
1160
1161#[unsafe(no_mangle)]
1162pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterErrorToChannel<'a>(
1163    env: EnvParam<'a>,
1164    channel: Pointer<'a, Sender<anyhow::Result<SinkWriterStreamResponse>>>,
1165    msg: JString<'a>,
1166) -> jboolean {
1167    execute_and_catch(env, move |env| {
1168        let ret = env.get_string(&msg);
1169        match ret {
1170            Ok(str) => {
1171                let err_msg: String = str.into();
1172                match channel.as_ref().blocking_send(Err(anyhow!(err_msg))) {
1173                    Ok(_) => Ok(JNI_TRUE),
1174                    Err(e) => {
1175                        tracing::info!(error = ?e.as_report(), "send error");
1176                        Ok(JNI_FALSE)
1177                    }
1178                }
1179            }
1180            Err(err) => {
1181                if msg.is_null() {
1182                    tracing::warn!("sink error message is null");
1183                    Ok(JNI_TRUE)
1184                } else {
1185                    tracing::error!(error = ?err.as_report(), "sink error message should be a java string");
1186                    Ok(JNI_FALSE)
1187                }
1188            }
1189        }
1190    })
1191}
1192
1193#[unsafe(no_mangle)]
1194pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkCoordinatorRequestFromChannel<
1195    'a,
1196>(
1197    env: EnvParam<'a>,
1198    mut channel: Pointer<'a, JniReceiverType<SinkCoordinatorStreamRequest>>,
1199) -> JByteArray<'a> {
1200    execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() {
1201        Some(msg) => {
1202            let bytes = env
1203                .byte_array_from_slice(&Message::encode_to_vec(&msg))
1204                .unwrap();
1205            Ok(bytes)
1206        }
1207        None => Ok(JObject::null().into()),
1208    })
1209}
1210
1211#[unsafe(no_mangle)]
1212pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkCoordinatorResponseToChannel<
1213    'a,
1214>(
1215    env: EnvParam<'a>,
1216    channel: Pointer<'a, JniSenderType<SinkCoordinatorStreamResponse>>,
1217    msg: JByteArray<'a>,
1218) -> jboolean {
1219    execute_and_catch(env, move |env| {
1220        let sink_coordinator_stream_response: SinkCoordinatorStreamResponse =
1221            Message::decode(to_guarded_slice(&msg, env)?.deref())?;
1222
1223        match channel
1224            .as_ref()
1225            .blocking_send(Ok(sink_coordinator_stream_response))
1226        {
1227            Ok(_) => Ok(JNI_TRUE),
1228            Err(e) => {
1229                tracing::info!(error = ?e.as_report(), "send error");
1230                Ok(JNI_FALSE)
1231            }
1232        }
1233    })
1234}
1235
#[cfg(test)]
mod tests {
    use risingwave_common::types::Timestamptz;

    /// Guards the assumption that a `Timestamptz` is represented as microseconds
    /// since the Unix epoch. `Java_com_risingwave_java_binding_Binding_iteratorGetTimestampValue`
    /// relies on this when it receives such a value as a `ScalarRefImpl::Int64`.
    #[test]
    fn test_timestamptz_to_i64() {
        let parsed: Timestamptz = "2023-06-01 09:45:00+08:00".parse().unwrap();
        // 2023-06-01T01:45:00Z expressed in microseconds since 1970-01-01T00:00:00Z.
        let expected = Timestamptz::from_micros(1_685_583_900_000_000);
        assert_eq!(parsed, expected);
    }
}