risingwave_jni_core/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![feature(error_generic_member_access)]
16#![feature(once_cell_try)]
17#![feature(type_alias_impl_trait)]
18#![feature(try_blocks)]
19
20pub mod jvm_runtime;
21mod macros;
22mod tracing_slf4j;
23
24use std::backtrace::Backtrace;
25use std::marker::PhantomData;
26use std::ops::{Deref, DerefMut};
27use std::slice::from_raw_parts;
28use std::sync::{LazyLock, OnceLock};
29
30use anyhow::anyhow;
31use bytes::Bytes;
32use cfg_or_panic::cfg_or_panic;
33use chrono::{DateTime, Datelike, Timelike};
34use futures::TryStreamExt;
35use futures::stream::BoxStream;
36use jni::JNIEnv;
37use jni::objects::{
38    AutoElements, GlobalRef, JByteArray, JClass, JMethodID, JObject, JStaticMethodID, JString,
39    JValueOwned, ReleaseMode,
40};
41use jni::signature::ReturnType;
42use jni::sys::{
43    JNI_FALSE, JNI_TRUE, jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue,
44};
45pub use paste::paste;
46use prost::{DecodeError, Message};
47use risingwave_common::array::{ArrayError, StreamChunk};
48use risingwave_common::hash::VirtualNode;
49use risingwave_common::row::{OwnedRow, Row};
50use risingwave_common::test_prelude::StreamChunkTestExt;
51use risingwave_common::types::{Decimal, ScalarRefImpl};
52use risingwave_common::util::panic::rw_catch_unwind;
53use risingwave_pb::connector_service::{
54    GetEventStreamResponse, SinkCoordinatorStreamRequest, SinkCoordinatorStreamResponse,
55    SinkWriterStreamRequest, SinkWriterStreamResponse,
56};
57use risingwave_pb::data::Op;
58use thiserror::Error;
59use thiserror_ext::AsReport;
60use tokio::runtime::Runtime;
61use tokio::sync::mpsc::{Receiver, Sender};
62use tracing_slf4j::*;
63
64pub static JAVA_BINDING_ASYNC_RUNTIME: LazyLock<Runtime> =
65    LazyLock::new(|| tokio::runtime::Runtime::new().unwrap());
66
/// Errors that can be produced inside a JNI binding call.
///
/// Every variant carries a `Backtrace` next to its source error. `execute_and_catch` turns
/// every variant except a pending Java exception (`Jni { error: JavaException, .. }`) into a
/// Java exception.
#[derive(Error, Debug)]
pub enum BindingError {
    /// Error reported by the `jni` crate. `jni::errors::Error::JavaException` signals that a
    /// Java exception is already pending in the JVM.
    #[error("JniError {error}")]
    Jni {
        #[from]
        error: jni::errors::Error,
        backtrace: Backtrace,
    },

    /// Any `anyhow::Error`, e.g. one yielded by the Hummock iterator stream in `iteratorNext`.
    #[error("StorageError {error}")]
    Storage {
        #[from]
        error: anyhow::Error,
        backtrace: Backtrace,
    },

    /// Failure while decoding a protobuf payload (see `newStreamChunkFromPayload`).
    #[error("DecodeError {error}")]
    Decode {
        #[from]
        error: DecodeError,
        backtrace: Backtrace,
    },

    /// Array error from `risingwave_common`; presumably raised when building a `StreamChunk`
    /// from protobuf — verify against `StreamChunk::from_protobuf`.
    #[error("StreamChunkArrayError {error}")]
    StreamChunkArray {
        #[from]
        error: ArrayError,
        backtrace: Backtrace,
    },
}

/// Result type returned by the closures passed to `execute_and_catch`.
type Result<T> = std::result::Result<T, BindingError>;
99
100pub fn to_guarded_slice<'array, 'env>(
101    array: &'array JByteArray<'env>,
102    env: &'array mut JNIEnv<'env>,
103) -> Result<SliceGuard<'env, 'array>> {
104    unsafe {
105        let array = env.get_array_elements(array, ReleaseMode::NoCopyBack)?;
106        let slice = from_raw_parts(array.as_ptr() as *mut u8, array.len());
107
108        Ok(SliceGuard {
109            _array: array,
110            slice,
111        })
112    }
113}
114
115/// Wrapper around `&[u8]` derived from `jbyteArray` to prevent it from being auto-released.
116pub struct SliceGuard<'env, 'array> {
117    _array: AutoElements<'env, 'env, 'array, jbyte>,
118    slice: &'array [u8],
119}
120
121impl Deref for SliceGuard<'_, '_> {
122    type Target = [u8];
123
124    fn deref(&self) -> &Self::Target {
125        self.slice
126    }
127}
128
/// Address of a Rust object, passed across JNI as a `jlong`.
///
/// `#[repr(transparent)]` gives it the same ABI as its single non-zero-sized field (`jlong`), so
/// it can appear directly in `extern "system"` signatures. The lifetime `'a` is not checked
/// against anything the JVM does; it only records how long the pointee is expected to live.
/// An address of `0` is treated as null (`Default` produces it and `as_ref`/`as_mut` assert
/// against it).
#[repr(transparent)]
pub struct Pointer<'a, T> {
    pointer: jlong,
    _phantom: PhantomData<&'a T>,
}
134
135impl<T> Default for Pointer<'_, T> {
136    fn default() -> Self {
137        Self {
138            pointer: 0,
139            _phantom: Default::default(),
140        }
141    }
142}
143
144impl<T> From<T> for Pointer<'static, T> {
145    fn from(value: T) -> Self {
146        Pointer {
147            pointer: Box::into_raw(Box::new(value)) as jlong,
148            _phantom: PhantomData,
149        }
150    }
151}
152
153impl<'a, T> Pointer<'a, T> {
154    fn as_ref(&self) -> &'a T {
155        assert!(self.pointer != 0);
156        unsafe { &*(self.pointer as *const T) }
157    }
158
159    fn as_mut(&mut self) -> &'a mut T {
160        assert!(self.pointer != 0);
161        unsafe { &mut *(self.pointer as *mut T) }
162    }
163}
164
165/// A pointer that owns the object it points to.
166///
167/// Note that dropping an `OwnedPointer` does not release the object.
168/// Instead, you should call [`OwnedPointer::release`] manually.
169pub type OwnedPointer<T> = Pointer<'static, T>;
170
171impl<T> OwnedPointer<T> {
172    /// Consume `self` and return the pointer value. Used for passing to JNI.
173    pub fn into_pointer(self) -> jlong {
174        self.pointer
175    }
176
177    /// Release the object behind the pointer.
178    fn release(self) {
179        tracing::debug!(
180            type_name = std::any::type_name::<T>(),
181            address = %format_args!("{:x}", self.pointer),
182            "release jni OwnedPointer"
183        );
184        assert!(self.pointer != 0);
185        unsafe { drop(Box::from_raw(self.pointer as *mut T)) }
186    }
187}
188
189/// In most Jni interfaces, the first parameter is `JNIEnv`, and the second parameter is `JClass`.
190/// This struct simply encapsulates the two common parameters into a single struct for simplicity.
191#[repr(C)]
192pub struct EnvParam<'a> {
193    env: JNIEnv<'a>,
194    class: JClass<'a>,
195}
196
197impl<'a> Deref for EnvParam<'a> {
198    type Target = JNIEnv<'a>;
199
200    fn deref(&self) -> &Self::Target {
201        &self.env
202    }
203}
204
205impl DerefMut for EnvParam<'_> {
206    fn deref_mut(&mut self) -> &mut Self::Target {
207        &mut self.env
208    }
209}
210
211impl<'a> EnvParam<'a> {
212    pub fn get_class(&self) -> &JClass<'a> {
213        &self.class
214    }
215}
216
217pub fn execute_and_catch<'env, F, Ret>(mut env: EnvParam<'env>, inner: F) -> Ret
218where
219    F: FnOnce(&mut EnvParam<'env>) -> Result<Ret>,
220    Ret: Default + 'env,
221{
222    match rw_catch_unwind(std::panic::AssertUnwindSafe(|| inner(&mut env))) {
223        Ok(Ok(ret)) => ret,
224        Ok(Err(e)) => {
225            match e {
226                BindingError::Jni {
227                    error: jni::errors::Error::JavaException,
228                    backtrace,
229                } => {
230                    tracing::error!("get JavaException thrown from: {:?}", backtrace);
231                    // the exception is already thrown. No need to throw again
232                }
233                _ => {
234                    env.throw(format!("get error while processing: {:?}", e.as_report()))
235                        .expect("should be able to throw");
236                }
237            }
238            Ret::default()
239        }
240        Err(e) => {
241            env.throw(format!("panic while processing: {:?}", e))
242                .expect("should be able to throw");
243            Ret::default()
244        }
245    }
246}
247
/// Per-iterator cache of Java classes (held as global refs) and method IDs that the value
/// getters use to build Java objects.
///
/// Each entry is filled lazily, at most once, by `OnceLock::get_or_try_init` the first time
/// the corresponding getter needs it.
#[derive(Default)]
struct JavaClassMethodCache {
    /// `java.math.BigDecimal` and its `(String)` constructor.
    big_decimal_ctor: OnceLock<(GlobalRef, JMethodID)>,

    /// `java.time.LocalDateTime` and its static `of(int x7)` factory.
    timestamp_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    /// `java.time.OffsetDateTime` and its static `ofInstant(Instant, ZoneId)` factory.
    timestamptz_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    /// `java.time.LocalDate` and its static `ofEpochDay(long)` factory.
    date_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    /// `java.time.LocalTime` and its static `of(int x4)` factory.
    time_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    /// `java.time.Instant` and its static `ofEpochSecond(long, long)` factory.
    instant_ctor: OnceLock<(GlobalRef, JStaticMethodID)>,
    /// The `java.time.ZoneOffset.UTC` constant.
    utc: OnceLock<GlobalRef>,
}
259
/// Holds the type-alias-impl-trait [`StreamChunkRowIterator`] together with the function that
/// defines its hidden type (`#[define_opaque]`).
mod opaque_type {
    use super::*;
    // TODO: may only return a RowRef
    /// Iterator over the rows of a borrowed `StreamChunk`. Each item is the row's protobuf `Op`
    /// plus an owned copy of the row.
    pub type StreamChunkRowIterator<'a> = impl Iterator<Item = (Op, OwnedRow)> + 'a;

    impl<'a> JavaBindingIteratorInner<'a> {
        /// Wraps the rows of `chunk`. The mapping is lazy: each row is copied into an
        /// `OwnedRow` only when the iterator advances to it.
        #[define_opaque(StreamChunkRowIterator)]
        pub(super) fn from_chunk(chunk: &'a StreamChunk) -> JavaBindingIteratorInner<'a> {
            JavaBindingIteratorInner::StreamChunk(
                chunk
                    .rows()
                    .map(|(op, row)| (op.to_protobuf(), row.to_owned_row())),
            )
        }
    }
}
pub use opaque_type::StreamChunkRowIterator;
/// Boxed stream of `(key, row)` pairs; presumably produced by a Hummock storage scan (confirm
/// with the callers of `new_hummock_iter`).
pub type HummockJavaBindingIterator = BoxStream<'static, anyhow::Result<(Bytes, OwnedRow)>>;
/// The row source behind a [`JavaBindingIterator`].
pub enum JavaBindingIteratorInner<'a> {
    /// Async stream polled via `JAVA_BINDING_ASYNC_RUNTIME`; each row carries its key.
    Hummock(HummockJavaBindingIterator),
    /// Rows of a borrowed `StreamChunk`; each row carries its `Op`.
    StreamChunk(StreamChunkRowIterator<'a>),
}
282
283enum RowExtra {
284    Op(Op),
285    Key(Bytes),
286}
287
288impl RowExtra {
289    fn as_op(&self) -> Op {
290        match self {
291            RowExtra::Op(op) => *op,
292            RowExtra::Key(_) => unreachable!("should be op"),
293        }
294    }
295
296    fn as_key(&self) -> &Bytes {
297        match self {
298            RowExtra::Key(key) => key,
299            RowExtra::Op(_) => unreachable!("should be key"),
300        }
301    }
302}
303
/// The row the iterator is currently positioned on, plus its per-row metadata.
struct RowCursor {
    row: OwnedRow,
    extra: RowExtra,
}

/// Iterator exposed to Java through a [`Pointer`].
///
/// Java advances it with `iteratorNext` and then reads fields of the current row through the
/// `iteratorGet*` functions.
pub struct JavaBindingIterator<'a> {
    inner: JavaBindingIteratorInner<'a>,
    /// Current row. It is `None` before the first `iteratorNext` call and after the iterator
    /// is exhausted.
    cursor: Option<RowCursor>,
    /// Lazily populated Java class/method cache used by the object-returning getters.
    class_cache: JavaClassMethodCache,
}
314
315impl JavaBindingIterator<'static> {
316    pub fn new_hummock_iter(iter: HummockJavaBindingIterator) -> Self {
317        Self {
318            inner: JavaBindingIteratorInner::Hummock(iter),
319            cursor: None,
320            class_cache: Default::default(),
321        }
322    }
323}
324
325impl Deref for JavaBindingIterator<'_> {
326    type Target = OwnedRow;
327
328    fn deref(&self) -> &Self::Target {
329        &self
330            .cursor
331            .as_ref()
332            .expect("should exist when call row methods")
333            .row
334    }
335}
336
337#[unsafe(no_mangle)]
338extern "system" fn Java_com_risingwave_java_binding_Binding_defaultVnodeCount(
339    _env: EnvParam<'_>,
340) -> jint {
341    VirtualNode::COUNT_FOR_COMPAT as jint
342}
343
344#[cfg_or_panic(not(madsim))]
345#[unsafe(no_mangle)]
346extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNewStreamChunk<'a>(
347    env: EnvParam<'a>,
348    chunk: Pointer<'a, StreamChunk>,
349) -> Pointer<'static, JavaBindingIterator<'a>> {
350    execute_and_catch(env, move |_env| {
351        let iter = JavaBindingIterator {
352            inner: JavaBindingIteratorInner::from_chunk(chunk.as_ref()),
353            cursor: None,
354            class_cache: Default::default(),
355        };
356        Ok(iter.into())
357    })
358}
359
360#[cfg_or_panic(not(madsim))]
361#[unsafe(no_mangle)]
362extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNext<'a>(
363    env: EnvParam<'a>,
364    mut pointer: Pointer<'a, JavaBindingIterator<'a>>,
365) -> jboolean {
366    execute_and_catch(env, move |_env| {
367        let iter = pointer.as_mut();
368        match &mut iter.inner {
369            JavaBindingIteratorInner::Hummock(hummock_iter) => {
370                match JAVA_BINDING_ASYNC_RUNTIME.block_on(hummock_iter.try_next())? {
371                    None => {
372                        iter.cursor = None;
373                        Ok(JNI_FALSE)
374                    }
375                    Some((key, row)) => {
376                        iter.cursor = Some(RowCursor {
377                            row,
378                            extra: RowExtra::Key(key),
379                        });
380                        Ok(JNI_TRUE)
381                    }
382                }
383            }
384            JavaBindingIteratorInner::StreamChunk(stream_chunk_iter) => {
385                match stream_chunk_iter.next() {
386                    None => {
387                        iter.cursor = None;
388                        Ok(JNI_FALSE)
389                    }
390                    Some((op, row)) => {
391                        iter.cursor = Some(RowCursor {
392                            row,
393                            extra: RowExtra::Op(op),
394                        });
395                        Ok(JNI_TRUE)
396                    }
397                }
398            }
399        }
400    })
401}
402
403#[unsafe(no_mangle)]
404extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorClose<'a>(
405    _env: EnvParam<'a>,
406    pointer: OwnedPointer<JavaBindingIterator<'a>>,
407) {
408    pointer.release()
409}
410
411#[unsafe(no_mangle)]
412extern "system" fn Java_com_risingwave_java_binding_Binding_newStreamChunkFromPayload<'a>(
413    env: EnvParam<'a>,
414    stream_chunk_payload: JByteArray<'a>,
415) -> Pointer<'static, StreamChunk> {
416    execute_and_catch(env, move |env| {
417        let prost_stream_chumk =
418            Message::decode(to_guarded_slice(&stream_chunk_payload, env)?.deref())?;
419        Ok(StreamChunk::from_protobuf(&prost_stream_chumk)?.into())
420    })
421}
422
423#[unsafe(no_mangle)]
424extern "system" fn Java_com_risingwave_java_binding_Binding_newStreamChunkFromPretty<'a>(
425    env: EnvParam<'a>,
426    str: JString<'a>,
427) -> Pointer<'static, StreamChunk> {
428    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
429        Ok(StreamChunk::from_pretty(env.get_string(&str)?.to_str().unwrap()).into())
430    })
431}
432
433#[unsafe(no_mangle)]
434extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkClose(
435    _env: EnvParam<'_>,
436    chunk: OwnedPointer<StreamChunk>,
437) {
438    chunk.release()
439}
440
441#[unsafe(no_mangle)]
442extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetKey<'a>(
443    env: EnvParam<'a>,
444    pointer: Pointer<'a, JavaBindingIterator<'a>>,
445) -> JByteArray<'a> {
446    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
447        Ok(env.byte_array_from_slice(
448            pointer
449                .as_ref()
450                .cursor
451                .as_ref()
452                .expect("should exists when call get key")
453                .extra
454                .as_key()
455                .as_ref(),
456        )?)
457    })
458}
459
460#[unsafe(no_mangle)]
461extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetOp<'a>(
462    env: EnvParam<'a>,
463    pointer: Pointer<'a, JavaBindingIterator<'a>>,
464) -> jint {
465    execute_and_catch(env, move |_env| {
466        Ok(pointer
467            .as_ref()
468            .cursor
469            .as_ref()
470            .expect("should exist when call get op")
471            .extra
472            .as_op() as jint)
473    })
474}
475
476#[unsafe(no_mangle)]
477extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorIsNull<'a>(
478    env: EnvParam<'a>,
479    pointer: Pointer<'a, JavaBindingIterator<'a>>,
480    idx: jint,
481) -> jboolean {
482    execute_and_catch(env, move |_env| {
483        Ok(pointer.as_ref().datum_at(idx as usize).is_none() as jboolean)
484    })
485}
486
487#[unsafe(no_mangle)]
488extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt16Value<'a>(
489    env: EnvParam<'a>,
490    pointer: Pointer<'a, JavaBindingIterator<'a>>,
491    idx: jint,
492) -> jshort {
493    execute_and_catch(env, move |_env| {
494        Ok(pointer
495            .as_ref()
496            .datum_at(idx as usize)
497            .unwrap()
498            .into_int16())
499    })
500}
501
502#[unsafe(no_mangle)]
503extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt32Value<'a>(
504    env: EnvParam<'a>,
505    pointer: Pointer<'a, JavaBindingIterator<'a>>,
506    idx: jint,
507) -> jint {
508    execute_and_catch(env, move |_env| {
509        Ok(pointer
510            .as_ref()
511            .datum_at(idx as usize)
512            .unwrap()
513            .into_int32())
514    })
515}
516
517#[unsafe(no_mangle)]
518extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt64Value<'a>(
519    env: EnvParam<'a>,
520    pointer: Pointer<'a, JavaBindingIterator<'a>>,
521    idx: jint,
522) -> jlong {
523    execute_and_catch(env, move |_env| {
524        Ok(pointer
525            .as_ref()
526            .datum_at(idx as usize)
527            .unwrap()
528            .into_int64())
529    })
530}
531
532#[unsafe(no_mangle)]
533extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetFloatValue<'a>(
534    env: EnvParam<'a>,
535    pointer: Pointer<'a, JavaBindingIterator<'a>>,
536    idx: jint,
537) -> jfloat {
538    execute_and_catch(env, move |_env| {
539        Ok(pointer
540            .as_ref()
541            .datum_at(idx as usize)
542            .unwrap()
543            .into_float32()
544            .into())
545    })
546}
547
548#[unsafe(no_mangle)]
549extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDoubleValue<'a>(
550    env: EnvParam<'a>,
551    pointer: Pointer<'a, JavaBindingIterator<'a>>,
552    idx: jint,
553) -> jdouble {
554    execute_and_catch(env, move |_env| {
555        Ok(pointer
556            .as_ref()
557            .datum_at(idx as usize)
558            .unwrap()
559            .into_float64()
560            .into())
561    })
562}
563
564#[unsafe(no_mangle)]
565extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetBooleanValue<'a>(
566    env: EnvParam<'a>,
567    pointer: Pointer<'a, JavaBindingIterator<'a>>,
568    idx: jint,
569) -> jboolean {
570    execute_and_catch(env, move |_env| {
571        Ok(pointer.as_ref().datum_at(idx as usize).unwrap().into_bool() as jboolean)
572    })
573}
574
575#[unsafe(no_mangle)]
576extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetStringValue<'a>(
577    env: EnvParam<'a>,
578    pointer: Pointer<'a, JavaBindingIterator<'a>>,
579    idx: jint,
580) -> JString<'a> {
581    execute_and_catch(env, move |env: &mut EnvParam<'a>| {
582        Ok(env.new_string(pointer.as_ref().datum_at(idx as usize).unwrap().into_utf8())?)
583    })
584}
585
586#[unsafe(no_mangle)]
587extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetIntervalValue<'a>(
588    env: EnvParam<'a>,
589    pointer: Pointer<'a, JavaBindingIterator<'a>>,
590    idx: jint,
591) -> JString<'a> {
592    execute_and_catch(env, move |env: &mut EnvParam<'a>| {
593        let interval = pointer
594            .as_ref()
595            .datum_at(idx as usize)
596            .unwrap()
597            .into_interval()
598            .as_iso_8601();
599        Ok(env.new_string(interval)?)
600    })
601}
602
603#[unsafe(no_mangle)]
604extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetJsonbValue<'a>(
605    env: EnvParam<'a>,
606    pointer: Pointer<'a, JavaBindingIterator<'a>>,
607    idx: jint,
608) -> JString<'a> {
609    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
610        let jsonb = pointer
611            .as_ref()
612            .datum_at(idx as usize)
613            .unwrap()
614            .into_jsonb()
615            .to_string();
616        Ok(env.new_string(jsonb)?)
617    })
618}
619
620#[unsafe(no_mangle)]
621extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimestampValue<'a>(
622    env: EnvParam<'a>,
623    pointer: Pointer<'a, JavaBindingIterator<'a>>,
624    idx: jint,
625) -> JObject<'a> {
626    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
627        let value = pointer
628            .as_ref()
629            .datum_at(idx as usize)
630            .unwrap()
631            .into_timestamp();
632
633        let sig = gen_jni_sig!(java.time.LocalDateTime of(int year, int month, int dayOfMonth, int hour, int minute, int second, int nanoOfSecond));
634
635        let (timestamp_class_ref, constructor) = pointer
636            .as_ref()
637            .class_cache
638            .timestamp_ctor
639            .get_or_try_init(|| {
640                let cls = env.find_class(gen_class_name!(java.time.LocalDateTime))?;
641                let init_method = env.get_static_method_id(&cls, "of", sig)?;
642                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
643            })?;
644        unsafe {
645            let JValueOwned::Object(timestamp_obj) = env.call_static_method_unchecked(
646                <&JClass<'_>>::from(timestamp_class_ref.as_obj()),
647                *constructor,
648                ReturnType::Object,
649                &[
650                    jvalue { i: value.0.year() },
651                    jvalue {
652                        i: value.0.month() as i32,
653                    },
654                    jvalue {
655                        i: value.0.day() as i32,
656                    },
657                    jvalue {
658                        i: value.0.hour() as i32,
659                    },
660                    jvalue {
661                        i: value.0.minute() as i32,
662                    },
663                    jvalue {
664                        i: value.0.second() as i32,
665                    },
666                    jvalue {
667                        i: value.0.nanosecond() as i32,
668                    },
669                ],
670            )?
671            else {
672                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
673                    name: "of".to_owned(),
674                    sig: sig.into(),
675                }));
676            };
677            Ok(timestamp_obj)
678        }
679    })
680}
681
682#[unsafe(no_mangle)]
683extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimestamptzValue<'a>(
684    env: EnvParam<'a>,
685    pointer: Pointer<'a, JavaBindingIterator<'a>>,
686    idx: jint,
687) -> JObject<'a> {
688    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
689        let value = pointer
690            .as_ref()
691            .datum_at(idx as usize)
692            .unwrap()
693            .into_timestamptz();
694
695        let instant_sig =
696            gen_jni_sig!(java.time.Instant ofEpochSecond(long epochSecond, long nanoAdjustment));
697
698        let (instant_class_ref, instant_constructor) = pointer
699            .as_ref()
700            .class_cache
701            .instant_ctor
702            .get_or_try_init(|| {
703                let cls = env.find_class(gen_class_name!(java.time.Instant))?;
704                let init_method = env.get_static_method_id(&cls, "ofEpochSecond", instant_sig)?;
705                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
706            })?;
707        let instant_obj = unsafe {
708            let JValueOwned::Object(instant_obj) = env.call_static_method_unchecked(
709                <&JClass<'_>>::from(instant_class_ref.as_obj()),
710                *instant_constructor,
711                ReturnType::Object,
712                &[
713                    jvalue {
714                        j: value.timestamp(),
715                    },
716                    jvalue {
717                        j: value.timestamp_subsec_nanos() as i64,
718                    },
719                ],
720            )?
721            else {
722                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
723                    name: "ofEpochSecond".to_owned(),
724                    sig: instant_sig.into(),
725                }));
726            };
727            instant_obj
728        };
729
730        let utc_ref = pointer.as_ref().class_cache.utc.get_or_try_init(|| {
731            let cls = env.find_class(gen_class_name!(java.time.ZoneOffset))?;
732            let utc = env
733                .get_static_field(&cls, "UTC", gen_jni_type_sig!(java.time.ZoneOffset))?
734                .l()?;
735            env.new_global_ref(utc)
736        })?;
737
738        let sig = gen_jni_sig!(java.time.OffsetDateTime ofInstant(java.time.Instant instant, java.time.ZoneId zone));
739
740        let (timestamptz_class_ref, constructor) = pointer
741            .as_ref()
742            .class_cache
743            .timestamptz_ctor
744            .get_or_try_init(|| {
745                let cls = env.find_class(gen_class_name!(java.time.OffsetDateTime))?;
746                let init_method = env.get_static_method_id(&cls, "ofInstant", sig)?;
747                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
748            })?;
749        unsafe {
750            let JValueOwned::Object(timestamptz_obj) = env.call_static_method_unchecked(
751                <&JClass<'_>>::from(timestamptz_class_ref.as_obj()),
752                *constructor,
753                ReturnType::Object,
754                &[
755                    jvalue {
756                        l: instant_obj.as_raw(),
757                    },
758                    jvalue {
759                        l: utc_ref.as_obj().as_raw(),
760                    },
761                ],
762            )?
763            else {
764                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
765                    name: "ofInstant".to_owned(),
766                    sig: sig.into(),
767                }));
768            };
769            Ok(timestamptz_obj)
770        }
771    })
772}
773
774#[unsafe(no_mangle)]
775extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDecimalValue<'a>(
776    env: EnvParam<'a>,
777    pointer: Pointer<'a, JavaBindingIterator<'a>>,
778    idx: jint,
779) -> JObject<'a> {
780    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
781        let decimal_value = pointer
782            .as_ref()
783            .datum_at(idx as usize)
784            .unwrap()
785            .into_decimal();
786
787        match decimal_value {
788            Decimal::NaN | Decimal::NegativeInf | Decimal::PositiveInf => {
789                return Ok(JObject::null());
790            }
791            Decimal::Normalized(_) => {}
792        };
793
794        let value = decimal_value.to_string();
795        let string_value = env.new_string(value)?;
796        let (decimal_class_ref, constructor) = pointer
797            .as_ref()
798            .class_cache
799            .big_decimal_ctor
800            .get_or_try_init(|| {
801                let cls = env.find_class("java/math/BigDecimal")?;
802                let init_method = env.get_method_id(&cls, "<init>", "(Ljava/lang/String;)V")?;
803                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
804            })?;
805        unsafe {
806            let decimal_class = <&JClass<'_>>::from(decimal_class_ref.as_obj());
807            let date_obj = env.new_object_unchecked(
808                decimal_class,
809                *constructor,
810                &[jvalue {
811                    l: string_value.into_raw(),
812                }],
813            )?;
814            Ok(date_obj)
815        }
816    })
817}
818
819#[unsafe(no_mangle)]
820extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDateValue<'a>(
821    env: EnvParam<'a>,
822    pointer: Pointer<'a, JavaBindingIterator<'a>>,
823    idx: jint,
824) -> JObject<'a> {
825    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
826        let value = pointer.as_ref().datum_at(idx as usize).unwrap().into_date();
827        let epoch_days = (value.0 - DateTime::UNIX_EPOCH.date_naive()).num_days();
828
829        let sig = gen_jni_sig!(java.time.LocalDate ofEpochDay(long));
830
831        let (date_class_ref, constructor) =
832            pointer.as_ref().class_cache.date_ctor.get_or_try_init(|| {
833                let cls = env.find_class(gen_class_name!(java.time.LocalDate))?;
834                let init_method = env.get_static_method_id(&cls, "ofEpochDay", sig)?;
835                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
836            })?;
837        unsafe {
838            let JValueOwned::Object(date_obj) = env.call_static_method_unchecked(
839                <&JClass<'_>>::from(date_class_ref.as_obj()),
840                *constructor,
841                ReturnType::Object,
842                &[jvalue { j: epoch_days }],
843            )?
844            else {
845                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
846                    name: "ofEpochDay".to_owned(),
847                    sig: sig.into(),
848                }));
849            };
850            Ok(date_obj)
851        }
852    })
853}
854
855#[unsafe(no_mangle)]
856extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimeValue<'a>(
857    env: EnvParam<'a>,
858    pointer: Pointer<'a, JavaBindingIterator<'a>>,
859    idx: jint,
860) -> JObject<'a> {
861    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
862        let value = pointer.as_ref().datum_at(idx as usize).unwrap().into_time();
863
864        let sig = gen_jni_sig!(java.time.LocalTime of(int hour, int minute, int second, int nanoOfSecond));
865
866        let (time_class_ref, constructor) =
867            pointer.as_ref().class_cache.time_ctor.get_or_try_init(|| {
868                let cls = env.find_class(gen_class_name!(java.time.LocalTime))?;
869                let init_method = env.get_static_method_id(&cls, "of", sig)?;
870                Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
871            })?;
872        unsafe {
873            let JValueOwned::Object(time_obj) = env.call_static_method_unchecked(
874                <&JClass<'_>>::from(time_class_ref.as_obj()),
875                *constructor,
876                ReturnType::Object,
877                &[
878                    jvalue {
879                        i: value.0.hour() as i32,
880                    },
881                    jvalue {
882                        i: value.0.minute() as i32,
883                    },
884                    jvalue {
885                        i: value.0.second() as i32,
886                    },
887                    jvalue {
888                        i: value.0.nanosecond() as i32,
889                    },
890                ],
891            )?
892            else {
893                return Err(BindingError::from(jni::errors::Error::MethodNotFound {
894                    name: "of".to_owned(),
895                    sig: sig.into(),
896                }));
897            };
898            Ok(time_obj)
899        }
900    })
901}
902
903#[unsafe(no_mangle)]
904extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetByteaValue<'a>(
905    env: EnvParam<'a>,
906    pointer: Pointer<'a, JavaBindingIterator<'a>>,
907    idx: jint,
908) -> JByteArray<'a> {
909    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
910        let bytes = pointer
911            .as_ref()
912            .datum_at(idx as usize)
913            .unwrap()
914            .into_bytea();
915        Ok(env.byte_array_from_slice(bytes)?)
916    })
917}
918
919#[unsafe(no_mangle)]
920extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetArrayValue<'a>(
921    env: EnvParam<'a>,
922    pointer: Pointer<'a, JavaBindingIterator<'a>>,
923    idx: jint,
924    class: JClass<'a>,
925) -> JObject<'a> {
926    execute_and_catch(env, move |env: &mut EnvParam<'_>| {
927        let elems = pointer
928            .as_ref()
929            .datum_at(idx as usize)
930            .unwrap()
931            .into_list()
932            .iter();
933
934        // convert the Rust elements to a Java object array (Object[])
935        let jarray = env.new_object_array(elems.len() as jsize, &class, JObject::null())?;
936
937        for (i, ele) in elems.enumerate() {
938            let index = i as jsize;
939            match ele {
940                None => env.set_object_array_element(&jarray, i as jsize, JObject::null())?,
941                Some(val) => match val {
942                    ScalarRefImpl::Int16(v) => {
943                        let o = call_static_method!(
944                            env,
945                            {Short},
946                            {Short	valueOf(short s)},
947                            v
948                        )?;
949                        env.set_object_array_element(&jarray, index, &o)?;
950                    }
951                    ScalarRefImpl::Int32(v) => {
952                        let o = call_static_method!(
953                            env,
954                            {Integer},
955                            {Integer	valueOf(int i)},
956                            v
957                        )?;
958                        env.set_object_array_element(&jarray, index, &o)?;
959                    }
960                    ScalarRefImpl::Int64(v) => {
961                        let o = call_static_method!(
962                            env,
963                            {Long},
964                            {Long	valueOf(long l)},
965                            v
966                        )?;
967                        env.set_object_array_element(&jarray, index, &o)?;
968                    }
969                    ScalarRefImpl::Float32(v) => {
970                        let o = call_static_method!(
971                            env,
972                            {Float},
973                            {Float	valueOf(float f)},
974                            v.into_inner()
975                        )?;
976                        env.set_object_array_element(&jarray, index, &o)?;
977                    }
978                    ScalarRefImpl::Float64(v) => {
979                        let o = call_static_method!(
980                            env,
981                            {Double},
982                            {Double	valueOf(double d)},
983                            v.into_inner()
984                        )?;
985                        env.set_object_array_element(&jarray, index, &o)?;
986                    }
987                    ScalarRefImpl::Utf8(v) => {
988                        let obj = env.new_string(v)?;
989                        env.set_object_array_element(&jarray, index, obj)?
990                    }
991                    _ => env.set_object_array_element(&jarray, index, JObject::null())?,
992                },
993            }
994        }
995        let output = unsafe { JObject::from_raw(jarray.into_raw()) };
996        Ok(output)
997    })
998}
999
1000pub type JniSenderType<T> = Sender<anyhow::Result<T>>;
1001pub type JniReceiverType<T> = Receiver<T>;
1002
1003/// Send messages to the channel received by `CdcSplitReader`.
1004/// If msg is null, just check whether the channel is closed.
1005/// Return true if sending is successful, otherwise, return false so that caller can stop
1006/// gracefully.
1007#[unsafe(no_mangle)]
1008extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel<'a>(
1009    env: EnvParam<'a>,
1010    channel: Pointer<'a, JniSenderType<GetEventStreamResponse>>,
1011    msg: JByteArray<'a>,
1012) -> jboolean {
1013    execute_and_catch(env, move |env| {
1014        // If msg is null means just check whether channel is closed.
1015        if msg.is_null() {
1016            if channel.as_ref().is_closed() {
1017                return Ok(JNI_FALSE);
1018            } else {
1019                return Ok(JNI_TRUE);
1020            }
1021        }
1022
1023        let get_event_stream_response: GetEventStreamResponse =
1024            Message::decode(to_guarded_slice(&msg, env)?.deref())?;
1025
1026        match channel
1027            .as_ref()
1028            .blocking_send(Ok(get_event_stream_response))
1029        {
1030            Ok(_) => Ok(JNI_TRUE),
1031            Err(e) => {
1032                tracing::info!(error = %e.as_report(), "send error");
1033                Ok(JNI_FALSE)
1034            }
1035        }
1036    })
1037}
1038
1039#[unsafe(no_mangle)]
1040extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceErrorToChannel<'a>(
1041    env: EnvParam<'a>,
1042    channel: Pointer<'a, JniSenderType<GetEventStreamResponse>>,
1043    msg: JString<'a>,
1044) -> jboolean {
1045    execute_and_catch(env, move |env| {
1046        let ret = env.get_string(&msg);
1047        match ret {
1048            Ok(str) => {
1049                let err_msg: String = str.into();
1050                match channel.as_ref().blocking_send(Err(anyhow!(err_msg))) {
1051                    Ok(_) => Ok(JNI_TRUE),
1052                    Err(e) => {
1053                        tracing::info!(error = ?e.as_report(), "send error");
1054                        Ok(JNI_FALSE)
1055                    }
1056                }
1057            }
1058            Err(err) => {
1059                if msg.is_null() {
1060                    tracing::warn!("source error message is null");
1061                    Ok(JNI_TRUE)
1062                } else {
1063                    tracing::error!(error = ?err.as_report(), "source error message should be a java string");
1064                    Ok(JNI_FALSE)
1065                }
1066            }
1067        }
1068    })
1069}
1070
/// Called from Java to dispose of a CDC source sender; the `OwnedPointer` is consumed by
/// `release()`.
///
/// NOTE(review): presumably `release()` frees the boxed sender, so the receiving side can
/// observe the channel closing once no other senders remain. Confirm against
/// `OwnedPointer::release`.
#[unsafe(no_mangle)]
extern "system" fn Java_com_risingwave_java_binding_Binding_cdcSourceSenderClose(
    _env: EnvParam<'_>,
    channel: OwnedPointer<JniSenderType<GetEventStreamResponse>>,
) {
    channel.release();
}
1078
/// A request handed to the Java sink writer by
/// `Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel`.
pub enum JniSinkWriterStreamRequest {
    /// A protobuf request. It is encoded to bytes and rebuilt in Java by
    /// `JniSinkWriterStreamRequest.fromSerializedPayload`.
    PbRequest(SinkWriterStreamRequest),
    /// A data chunk passed without serialization. The chunk is boxed, and its raw pointer is
    /// given to `JniSinkWriterStreamRequest.fromStreamChunkOwnedPointer` together with `epoch`
    /// and `batch_id`.
    Chunk {
        epoch: u64,
        batch_id: u64,
        chunk: StreamChunk,
    },
}
1087
1088impl From<SinkWriterStreamRequest> for JniSinkWriterStreamRequest {
1089    fn from(value: SinkWriterStreamRequest) -> Self {
1090        Self::PbRequest(value)
1091    }
1092}
1093
1094#[unsafe(no_mangle)]
1095pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel<
1096    'a,
1097>(
1098    env: EnvParam<'a>,
1099    mut channel: Pointer<'a, JniReceiverType<JniSinkWriterStreamRequest>>,
1100) -> JObject<'a> {
1101    execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() {
1102        Some(msg) => {
1103            let obj = match msg {
1104                JniSinkWriterStreamRequest::PbRequest(request) => {
1105                    let bytes = env.byte_array_from_slice(&Message::encode_to_vec(&request))?;
1106                    let jobj = JObject::from(bytes);
1107                    call_static_method!(
1108                        env,
1109                        {com.risingwave.java.binding.JniSinkWriterStreamRequest},
1110                        {com.risingwave.java.binding.JniSinkWriterStreamRequest fromSerializedPayload(byte[] payload)},
1111                        &jobj
1112                    )?
1113                }
1114                JniSinkWriterStreamRequest::Chunk {
1115                    epoch,
1116                    batch_id,
1117                    chunk,
1118                } => {
1119                    let pointer = Box::into_raw(Box::new(chunk));
1120                    call_static_method!(
1121                        env,
1122                        {com.risingwave.java.binding.JniSinkWriterStreamRequest},
1123                        {com.risingwave.java.binding.JniSinkWriterStreamRequest fromStreamChunkOwnedPointer(long pointer, long epoch, long batchId)},
1124                        pointer as u64, epoch, batch_id
1125                    )
1126                    .inspect_err(|_| unsafe {
1127                        // release the stream chunk on err
1128                        drop(Box::from_raw(pointer));
1129                    })?
1130                }
1131            };
1132            Ok(obj)
1133        }
1134        None => Ok(JObject::null()),
1135    })
1136}
1137
1138#[unsafe(no_mangle)]
1139pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterResponseToChannel<
1140    'a,
1141>(
1142    env: EnvParam<'a>,
1143    channel: Pointer<'a, JniSenderType<SinkWriterStreamResponse>>,
1144    msg: JByteArray<'a>,
1145) -> jboolean {
1146    execute_and_catch(env, move |env| {
1147        let sink_writer_stream_response: SinkWriterStreamResponse =
1148            Message::decode(to_guarded_slice(&msg, env)?.deref())?;
1149
1150        match channel
1151            .as_ref()
1152            .blocking_send(Ok(sink_writer_stream_response))
1153        {
1154            Ok(_) => Ok(JNI_TRUE),
1155            Err(e) => {
1156                tracing::info!(error = ?e.as_report(), "send error");
1157                Ok(JNI_FALSE)
1158            }
1159        }
1160    })
1161}
1162
1163#[unsafe(no_mangle)]
1164pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterErrorToChannel<'a>(
1165    env: EnvParam<'a>,
1166    channel: Pointer<'a, Sender<anyhow::Result<SinkWriterStreamResponse>>>,
1167    msg: JString<'a>,
1168) -> jboolean {
1169    execute_and_catch(env, move |env| {
1170        let ret = env.get_string(&msg);
1171        match ret {
1172            Ok(str) => {
1173                let err_msg: String = str.into();
1174                match channel.as_ref().blocking_send(Err(anyhow!(err_msg))) {
1175                    Ok(_) => Ok(JNI_TRUE),
1176                    Err(e) => {
1177                        tracing::info!(error = ?e.as_report(), "send error");
1178                        Ok(JNI_FALSE)
1179                    }
1180                }
1181            }
1182            Err(err) => {
1183                if msg.is_null() {
1184                    tracing::warn!("sink error message is null");
1185                    Ok(JNI_TRUE)
1186                } else {
1187                    tracing::error!(error = ?err.as_report(), "sink error message should be a java string");
1188                    Ok(JNI_FALSE)
1189                }
1190            }
1191        }
1192    })
1193}
1194
1195#[unsafe(no_mangle)]
1196pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkCoordinatorRequestFromChannel<
1197    'a,
1198>(
1199    env: EnvParam<'a>,
1200    mut channel: Pointer<'a, JniReceiverType<SinkCoordinatorStreamRequest>>,
1201) -> JByteArray<'a> {
1202    execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() {
1203        Some(msg) => {
1204            let bytes = env
1205                .byte_array_from_slice(&Message::encode_to_vec(&msg))
1206                .unwrap();
1207            Ok(bytes)
1208        }
1209        None => Ok(JObject::null().into()),
1210    })
1211}
1212
1213#[unsafe(no_mangle)]
1214pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkCoordinatorResponseToChannel<
1215    'a,
1216>(
1217    env: EnvParam<'a>,
1218    channel: Pointer<'a, JniSenderType<SinkCoordinatorStreamResponse>>,
1219    msg: JByteArray<'a>,
1220) -> jboolean {
1221    execute_and_catch(env, move |env| {
1222        let sink_coordinator_stream_response: SinkCoordinatorStreamResponse =
1223            Message::decode(to_guarded_slice(&msg, env)?.deref())?;
1224
1225        match channel
1226            .as_ref()
1227            .blocking_send(Ok(sink_coordinator_stream_response))
1228        {
1229            Ok(_) => Ok(JNI_TRUE),
1230            Err(e) => {
1231                tracing::info!(error = ?e.as_report(), "send error");
1232                Ok(JNI_FALSE)
1233            }
1234        }
1235    })
1236}
1237
#[cfg(test)]
mod tests {
    use risingwave_common::types::Timestamptz;

    /// Checks that a `Timestamptz` is stored as microseconds since the Unix epoch. This is the
    /// representation `Java_com_risingwave_java_binding_Binding_iteratorGetTimestampValue`
    /// assumes for the `Int64` scalar it receives.
    #[test]
    fn test_timestamptz_to_i64() {
        // 09:45 at UTC+8 is 01:45 UTC on 2023-06-01.
        let parsed: Timestamptz = "2023-06-01 09:45:00+08:00".parse().unwrap();
        let expected = Timestamptz::from_micros(1_685_583_900_000_000);
        assert_eq!(parsed, expected);
    }
}