risingwave_expr_impl/udf/
quickjs.rs1use arrow_udf_js::{AggregateOptions, FunctionOptions, Runtime};
16use futures_util::{FutureExt, StreamExt};
17use risingwave_common::array::arrow::arrow_schema_udf::{DataType, Field};
18use risingwave_common::array::arrow::{UdfArrowConvert, UdfToArrow};
19
20use super::*;
21
22#[linkme::distributed_slice(UDF_IMPLS)]
23static QUICKJS: UdfImplDescriptor = UdfImplDescriptor {
24 match_fn: |language, runtime, _link| {
25 language == "javascript" && matches!(runtime, None | Some("quickjs"))
26 },
27 create_fn: |opts| {
28 Ok(CreateFunctionOutput {
29 name_in_runtime: opts.name.to_owned(),
30 body: Some(opts.as_.context("AS must be specified")?.to_owned()),
31 compressed_binary: None,
32 })
33 },
34 build_fn: |opts| {
35 futures::executor::block_on(async {
38 let mut runtime = Runtime::new()
39 .await
40 .context("failed to create QuickJS Runtime")?;
41 if opts.is_async.unwrap_or(false) {
42 runtime
43 .enable_fetch()
44 .await
45 .context("failed to enable fetch")?;
46 }
47 if opts.kind.is_aggregate() {
48 let options = AggregateOptions {
49 is_async: opts.is_async.unwrap_or(false),
50 ..Default::default()
51 };
52 runtime
53 .add_aggregate(
54 opts.name_in_runtime,
55 Field::new("state", DataType::Binary, true).with_metadata(
56 [("ARROW:extension:name".into(), "arrowudf.json".into())].into(),
57 ),
58 UdfArrowConvert::default().to_arrow_field("", opts.return_type)?,
59 opts.body.context("body is required")?,
60 options,
61 )
62 .await
63 .context("failed to add_aggregate")?;
64 } else {
65 let options = FunctionOptions {
66 is_async: opts.is_async.unwrap_or(false),
67 is_batched: opts.is_batched.unwrap_or(false),
68 ..Default::default()
69 };
70 let res = runtime
71 .add_function(
72 opts.name_in_runtime,
73 UdfArrowConvert::default().to_arrow_field("", opts.return_type)?,
74 opts.body.context("body is required")?,
75 options.clone(),
76 )
77 .await;
78
79 if res.is_err() {
80 let body = format!(
83 "export function{} {}({}) {{ {} }}",
84 if opts.kind.is_table() { "*" } else { "" },
85 opts.name_in_runtime,
86 opts.arg_names.join(","),
87 opts.body.context("body is required")?,
88 );
89 runtime
90 .add_function(
91 opts.name_in_runtime,
92 UdfArrowConvert::default().to_arrow_field("", opts.return_type)?,
93 &body,
94 options,
95 )
96 .await
97 .context("failed to add_function")?;
98 }
99 }
100 Ok(Box::new(QuickJsFunction {
101 runtime,
102 name: opts.name_in_runtime.to_owned(),
103 }) as Box<dyn UdfImpl>)
104 })
105 },
106};
107
108#[derive(Debug)]
109struct QuickJsFunction {
110 runtime: Runtime,
111 name: String,
112}
113
114#[async_trait::async_trait]
115impl UdfImpl for QuickJsFunction {
116 async fn call(&self, input: &RecordBatch) -> Result<RecordBatch> {
117 self.runtime.call(&self.name, input).await
119 }
120
121 async fn call_table_function<'a>(
122 &'a self,
123 input: &'a RecordBatch,
124 ) -> Result<BoxStream<'a, Result<RecordBatch>>> {
125 let iter = self.runtime.call_table_function(&self.name, input, 1024)?;
126 Ok(Box::pin(iter))
127 }
128
129 async fn call_agg_create_state(&self) -> Result<ArrayRef> {
130 self.runtime.create_state(&self.name).await
131 }
132
133 async fn call_agg_accumulate_or_retract(
134 &self,
135 state: &ArrayRef,
136 ops: &BooleanArray,
137 input: &RecordBatch,
138 ) -> Result<ArrayRef> {
139 self.runtime
140 .accumulate_or_retract(&self.name, state, ops, input)
141 .await
142 }
143
144 async fn call_agg_finish(&self, state: &ArrayRef) -> Result<ArrayRef> {
145 self.runtime.finish(&self.name, state).await
146 }
147
148 fn memory_usage(&self) -> usize {
149 futures::executor::block_on(async {
150 self.runtime.inner().memory_usage().await.malloc_size as usize
152 })
153 }
154}