// risingwave_expr_impl/udf/quickjs.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use arrow_udf_js::{AggregateOptions, FunctionOptions, Runtime};
16use futures_util::{FutureExt, StreamExt};
17use risingwave_common::array::arrow::arrow_schema_udf::{DataType, Field};
18use risingwave_common::array::arrow::{UdfArrowConvert, UdfToArrow};
19
20use super::*;
21
22#[linkme::distributed_slice(UDF_IMPLS)]
23static QUICKJS: UdfImplDescriptor = UdfImplDescriptor {
24    match_fn: |language, runtime, _link| {
25        language == "javascript" && matches!(runtime, None | Some("quickjs"))
26    },
27    create_fn: |opts| {
28        Ok(CreateFunctionOutput {
29            name_in_runtime: opts.name.to_owned(),
30            body: Some(opts.as_.context("AS must be specified")?.to_owned()),
31            compressed_binary: None,
32        })
33    },
34    build_fn: |opts| {
35        // NOTE: Some function calls such as `add_function()` requires async.
36        // However, since the `Runtime` here is not shared, the async block will never block.
37        futures::executor::block_on(async {
38            let mut runtime = Runtime::new()
39                .await
40                .context("failed to create QuickJS Runtime")?;
41            if opts.is_async.unwrap_or(false) {
42                runtime
43                    .enable_fetch()
44                    .await
45                    .context("failed to enable fetch")?;
46            }
47            if opts.kind.is_aggregate() {
48                let options = AggregateOptions {
49                    is_async: opts.is_async.unwrap_or(false),
50                    ..Default::default()
51                };
52                runtime
53                    .add_aggregate(
54                        opts.name_in_runtime,
55                        Field::new("state", DataType::Binary, true).with_metadata(
56                            [("ARROW:extension:name".into(), "arrowudf.json".into())].into(),
57                        ),
58                        UdfArrowConvert::default().to_arrow_field("", opts.return_type)?,
59                        opts.body.context("body is required")?,
60                        options,
61                    )
62                    .await
63                    .context("failed to add_aggregate")?;
64            } else {
65                let options = FunctionOptions {
66                    is_async: opts.is_async.unwrap_or(false),
67                    is_batched: opts.is_batched.unwrap_or(false),
68                    ..Default::default()
69                };
70                let res = runtime
71                    .add_function(
72                        opts.name_in_runtime,
73                        UdfArrowConvert::default().to_arrow_field("", opts.return_type)?,
74                        opts.body.context("body is required")?,
75                        options.clone(),
76                    )
77                    .await;
78
79                if res.is_err() {
80                    // COMPATIBILITY: This is for keeping compatible with the legacy syntax that
81                    // only function body is provided by users.
82                    let body = format!(
83                        "export function{} {}({}) {{ {} }}",
84                        if opts.kind.is_table() { "*" } else { "" },
85                        opts.name_in_runtime,
86                        opts.arg_names.join(","),
87                        opts.body.context("body is required")?,
88                    );
89                    runtime
90                        .add_function(
91                            opts.name_in_runtime,
92                            UdfArrowConvert::default().to_arrow_field("", opts.return_type)?,
93                            &body,
94                            options,
95                        )
96                        .await
97                        .context("failed to add_function")?;
98                }
99            }
100            Ok(Box::new(QuickJsFunction {
101                runtime,
102                name: opts.name_in_runtime.to_owned(),
103            }) as Box<dyn UdfImpl>)
104        })
105    },
106};
107
108#[derive(Debug)]
109struct QuickJsFunction {
110    runtime: Runtime,
111    name: String,
112}
113
114#[async_trait::async_trait]
115impl UdfImpl for QuickJsFunction {
116    async fn call(&self, input: &RecordBatch) -> Result<RecordBatch> {
117        // TODO(eric): if not batched, call JS function row by row. Otherwise, one row failure will fail the entire chunk.
118        self.runtime.call(&self.name, input).await
119    }
120
121    async fn call_table_function<'a>(
122        &'a self,
123        input: &'a RecordBatch,
124    ) -> Result<BoxStream<'a, Result<RecordBatch>>> {
125        let iter = self.runtime.call_table_function(&self.name, input, 1024)?;
126        Ok(Box::pin(iter))
127    }
128
129    async fn call_agg_create_state(&self) -> Result<ArrayRef> {
130        self.runtime.create_state(&self.name).await
131    }
132
133    async fn call_agg_accumulate_or_retract(
134        &self,
135        state: &ArrayRef,
136        ops: &BooleanArray,
137        input: &RecordBatch,
138    ) -> Result<ArrayRef> {
139        self.runtime
140            .accumulate_or_retract(&self.name, state, ops, input)
141            .await
142    }
143
144    async fn call_agg_finish(&self, state: &ArrayRef) -> Result<ArrayRef> {
145        self.runtime.finish(&self.name, state).await
146    }
147
148    fn memory_usage(&self) -> usize {
149        futures::executor::block_on(async {
150            // As `runtime` is not shared, this await will never block.
151            self.runtime.inner().memory_usage().await.malloc_size as usize
152        })
153    }
154}