risingwave_jni_core/
jvm_runtime.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ffi::c_void;
16use std::path::PathBuf;
17
18use anyhow::{Context, bail};
19use fs_err as fs;
20use fs_err::PathExt;
21use jni::objects::{JObject, JString};
22use jni::{AttachGuard, InitArgsBuilder, JNIEnv, JNIVersion, JavaVM};
23use risingwave_common::global_jvm::JVM;
24use risingwave_common::util::resource_util::memory::system_memory_available_bytes;
25use thiserror_ext::AsReport;
26use tracing::error;
27
28use crate::{call_method, call_static_method};
29
/// Default fraction of the available system memory used as the JVM max heap size (`-Xmx`)
/// when the `JVM_HEAP_SIZE` environment variable is not set.
///
/// 0.07 corresponds to 10% of the compute node's total memory, given that the compute node
/// uses 0.7 * system memory by default.
const DEFAULT_MEMORY_PROPORTION: f64 = 0.07;
32
33fn locate_libs_path() -> anyhow::Result<PathBuf> {
34    let libs_path = if let Ok(libs_path) = std::env::var("CONNECTOR_LIBS_PATH") {
35        PathBuf::from(libs_path)
36    } else {
37        tracing::info!(
38            "environment variable CONNECTOR_LIBS_PATH is not specified, use default path `./libs` instead"
39        );
40        std::env::current_exe()
41            .and_then(|p| p.fs_err_canonicalize()) // resolve symlink of the current executable
42            .context("unable to get path of the executable")?
43            .parent()
44            .expect("not root")
45            .join("libs")
46    };
47    Ok(libs_path)
48}
49
50pub fn build_jvm_with_native_registration() -> anyhow::Result<JavaVM> {
51    let libs_path = locate_libs_path().context("failed to locate connector libs")?;
52    tracing::info!(path = %libs_path.display(), "located connector libs");
53
54    let mut class_vec = vec![];
55
56    let entries = fs::read_dir(&libs_path).context(if cfg!(debug_assertions) {
57        "failed to read connector libs; \
58        for RiseDev users, please check if ENABLE_BUILD_RW_CONNECTOR is set with `risedev configure`"
59    } else {
60        "failed to read connector libs, \
61        please check if env var CONNECTOR_LIBS_PATH is correctly configured"
62    })?;
63    for entry in entries.flatten() {
64        let entry_path = entry.path();
65        if entry_path.file_name().is_some() {
66            let path = fs::canonicalize(entry_path)
67                .expect("invalid entry_path obtained from fs::read_dir");
68            class_vec.push(path.to_str().unwrap().to_owned());
69        }
70    }
71
72    // move risingwave-source-cdc to the head of classpath, because we have some patched Debezium classes
73    // in this jar which needs to be loaded first.
74    let mut new_class_vec = Vec::with_capacity(class_vec.len());
75    for path in class_vec {
76        if path.contains("risingwave-source-cdc") {
77            new_class_vec.insert(0, path.clone());
78        } else {
79            new_class_vec.push(path.clone());
80        }
81    }
82    class_vec = new_class_vec;
83
84    let jvm_heap_size = if let Ok(heap_size) = std::env::var("JVM_HEAP_SIZE") {
85        heap_size
86    } else {
87        format!(
88            "{}",
89            (system_memory_available_bytes() as f64 * DEFAULT_MEMORY_PROPORTION) as usize
90        )
91    };
92
93    // FIXME: passing custom arguments to the embedded jvm when compute node start
94    // Build the VM properties
95    let args_builder = InitArgsBuilder::new()
96        // Pass the JNI API version (default is 8)
97        .version(JNIVersion::V8)
98        .option("-Dis_embedded_connector=true")
99        .option(format!("-Djava.class.path={}", class_vec.join(":")))
100        .option("--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED")
101        .option("-Xms16m")
102        .option(format!("-Xmx{}", jvm_heap_size));
103
104    tracing::info!("JVM args: {:?}", args_builder);
105    let jvm_args = args_builder.build().context("invalid jvm args")?;
106
107    // Create a new VM
108    let jvm = match JavaVM::new(jvm_args) {
109        Err(err) => {
110            tracing::error!(error = ?err.as_report(), "fail to new JVM");
111            bail!("fail to new JVM");
112        }
113        Ok(jvm) => jvm,
114    };
115
116    tracing::info!("initialize JVM successfully");
117
118    let result: std::result::Result<(), jni::errors::Error> = try {
119        let mut env = jvm_env(&jvm)?;
120        register_java_binding_native_methods(&mut env)?;
121    };
122
123    result.context("failed to register native method")?;
124
125    Ok(jvm)
126}
127
128pub fn jvm_env(jvm: &JavaVM) -> Result<AttachGuard<'_>, jni::errors::Error> {
129    jvm.attach_current_thread()
130        .inspect_err(|e| tracing::error!(error = ?e.as_report(), "jvm attach thread error"))
131}
132
133pub fn register_java_binding_native_methods(
134    env: &mut JNIEnv<'_>,
135) -> Result<(), jni::errors::Error> {
136    let binding_class = env
137        .find_class(gen_class_name!(com.risingwave.java.binding.Binding))
138        .inspect_err(|e| tracing::error!(error = ?e.as_report(), "jvm find class error"))?;
139    use crate::*;
140    macro_rules! gen_native_method_array {
141        () => {{
142            $crate::for_all_native_methods! {gen_native_method_array}
143        }};
144        ({$({ $func_name:ident, {$($ret:tt)+}, {$($args:tt)*} })*}) => {
145            [
146                $(
147                    $crate::gen_native_method_entry! {
148                        Java_com_risingwave_java_binding_Binding_, $func_name, {$($ret)+}, {$($args)*}
149                    },
150                )*
151            ]
152        }
153    }
154    env.register_native_methods(binding_class, &gen_native_method_array!())
155        .inspect_err(
156            |e| tracing::error!(error = ?e.as_report(), "jvm register native methods error"),
157        )?;
158
159    tracing::info!("register native methods for jvm successfully");
160    Ok(())
161}
162
163/// Load JVM memory statistics from the runtime. If JVM is not initialized or fail to initialize,
164/// return zero.
165pub fn load_jvm_memory_stats() -> (usize, usize) {
166    match JVM.get() {
167        Some(jvm) => {
168            let result: Result<(usize, usize), anyhow::Error> = try {
169                execute_with_jni_env(jvm, |env| {
170                    let runtime_instance = crate::call_static_method!(
171                        env,
172                        {Runtime},
173                        {Runtime getRuntime()}
174                    )?;
175
176                    let total_memory =
177                        call_method!(env, runtime_instance.as_ref(), {long totalMemory()})?;
178                    let free_memory =
179                        call_method!(env, runtime_instance.as_ref(), {long freeMemory()})?;
180
181                    Ok((total_memory as usize, (total_memory - free_memory) as usize))
182                })?
183            };
184            match result {
185                Ok(ret) => ret,
186                Err(e) => {
187                    error!(error = ?e.as_report(), "failed to collect jvm stats");
188                    (0, 0)
189                }
190            }
191        }
192        _ => (0, 0),
193    }
194}
195
196pub fn execute_with_jni_env<T>(
197    jvm: &JavaVM,
198    f: impl FnOnce(&mut JNIEnv<'_>) -> anyhow::Result<T>,
199) -> anyhow::Result<T> {
200    let mut env = jvm
201        .attach_current_thread()
202        .with_context(|| "Failed to attach current rust thread to jvm")?;
203
204    // set context class loader for the thread
205    // java.lang.Thread.currentThread()
206    //     .setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());
207
208    let thread = crate::call_static_method!(
209        env,
210        {Thread},
211        {Thread currentThread()}
212    )?;
213
214    let system_class_loader = crate::call_static_method!(
215        env,
216        {ClassLoader},
217        {ClassLoader getSystemClassLoader()}
218    )?;
219
220    crate::call_method!(
221        env,
222        thread,
223        {void setContextClassLoader(ClassLoader)},
224        &system_class_loader
225    )?;
226
227    let ret = f(&mut env);
228
229    match env.exception_check() {
230        Ok(true) => {
231            let exception = env.exception_occurred().inspect_err(|e| {
232                tracing::warn!(error = %e.as_report(), "Failed to get jvm exception");
233            })?;
234            env.exception_describe().inspect_err(|e| {
235                tracing::warn!(error = %e.as_report(), "Failed to describe jvm exception");
236            })?;
237            env.exception_clear().inspect_err(|e| {
238                tracing::warn!(error = %e.as_report(), "Exception occurred but failed to clear");
239            })?;
240            let message = call_method!(env, exception, {String getMessage()})?;
241            let message = jobj_to_str(&mut env, message)?;
242            return Err(anyhow::anyhow!("Caught Java Exception: {}", message));
243        }
244        Ok(false) => {
245            // No exception, do nothing
246        }
247        Err(e) => {
248            tracing::warn!(error = %e.as_report(), "Failed to check exception");
249        }
250    }
251
252    ret
253}
254
255/// A helper method to convert an java object to rust string.
256pub fn jobj_to_str(env: &mut JNIEnv<'_>, obj: JObject<'_>) -> anyhow::Result<String> {
257    if !env.is_instance_of(&obj, "java/lang/String")? {
258        bail!("Input object is not a java string and can't be converted!")
259    }
260    let jstr = JString::from(obj);
261    let java_str = env.get_string(&jstr)?;
262    Ok(java_str.to_str()?.to_owned())
263}
264
265/// Dumps the JVM stack traces.
266///
267/// # Returns
268///
269/// - `Ok(None)` if JVM is not initialized.
270/// - `Ok(Some(String))` if JVM is initialized and stack traces are dumped.
271/// - `Err` if failed to dump stack traces.
272pub fn dump_jvm_stack_traces() -> anyhow::Result<Option<String>> {
273    match JVM.get() {
274        None => Ok(None),
275        Some(jvm) => execute_with_jni_env(jvm, |env| {
276            let result = call_static_method!(
277                env,
278                {com.risingwave.connector.api.Monitor},
279                {String dumpStackTrace()}
280            )
281            .with_context(|| "Failed to call Java function")?;
282            let result = JString::from(result);
283            let result = env
284                .get_string(&result)
285                .with_context(|| "Failed to convert JString")?;
286            let result = result
287                .to_str()
288                .with_context(|| "Failed to convert JavaStr")?;
289            Ok(Some(result.to_owned()))
290        }),
291    }
292}
293
294/// Register the JVM initialization closure.
295pub fn register_jvm_builder() {
296    JVM.register_jvm_builder(Box::new(|| {
297        build_jvm_with_native_registration().expect("failed to build JVM with native registration")
298    }));
299}