risingwave_expr_impl/scalar/
vnode.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use anyhow::Context;
18use itertools::Itertools;
19use risingwave_common::array::{ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, I16ArrayBuilder};
20use risingwave_common::hash::VirtualNode;
21use risingwave_common::row::OwnedRow;
22use risingwave_common::types::{DataType, Datum};
23use risingwave_expr::expr::{BoxedExpression, Expression};
24use risingwave_expr::{Result, build_function, expr_context};
25
26#[derive(Debug)]
27struct VnodeExpression {
28    /// `Some` if it's from the first argument of user-facing function `VnodeUser` (`rw_vnode`),
29    /// `None` if it's from the internal function `Vnode`.
30    vnode_count: Option<usize>,
31
32    /// A list of expressions to get the distribution key columns. Typically `InputRef`.
33    children: Vec<BoxedExpression>,
34
35    /// Normally, we pass the distribution key indices to `VirtualNode::compute_xx` functions.
36    /// But in this case, all children columns are used to compute vnode. So we cache a vector of
37    /// all indices here and pass it later to reduce allocation.
38    all_indices: Vec<usize>,
39}
40
41#[build_function("vnode(...) -> int2")]
42fn build(_: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
43    Ok(Box::new(VnodeExpression {
44        vnode_count: None,
45        all_indices: (0..children.len()).collect(),
46        children,
47    }))
48}
49
50#[build_function("vnode_user(...) -> int2")]
51fn build_user(_: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
52    let mut children = children.into_iter();
53
54    let vnode_count = children
55        .next()
56        .unwrap() // always exist, argument number enforced in binder
57        .eval_const() // required to be constant
58        .context("the first argument (vnode count) must be a constant")?
59        .context("the first argument (vnode count) must not be NULL")?
60        .into_int32(); // always int32, casted during type inference
61
62    if !(1i32..=VirtualNode::MAX_COUNT as i32).contains(&vnode_count) {
63        return Err(anyhow::anyhow!(
64            "the first argument (vnode count) must be in range 1..={}",
65            VirtualNode::MAX_COUNT
66        )
67        .into());
68    }
69
70    let children = children.collect_vec();
71
72    Ok(Box::new(VnodeExpression {
73        vnode_count: Some(vnode_count.try_into().unwrap()),
74        all_indices: (0..children.len()).collect(),
75        children,
76    }))
77}
78
79#[async_trait::async_trait]
80impl Expression for VnodeExpression {
81    fn return_type(&self) -> DataType {
82        DataType::Int16
83    }
84
85    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
86        let mut arrays = Vec::with_capacity(self.children.len());
87        for child in &self.children {
88            arrays.push(child.eval(input).await?);
89        }
90        let input = DataChunk::new(arrays, input.visibility().clone());
91
92        let vnodes = VirtualNode::compute_chunk(&input, &self.all_indices, self.vnode_count()?);
93        let mut builder = I16ArrayBuilder::new(input.capacity());
94        vnodes
95            .into_iter()
96            .for_each(|vnode| builder.append(Some(vnode.to_scalar())));
97        Ok(Arc::new(ArrayImpl::from(builder.finish())))
98    }
99
100    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
101        let mut datums = Vec::with_capacity(self.children.len());
102        for child in &self.children {
103            datums.push(child.eval_row(input).await?);
104        }
105        let input = OwnedRow::new(datums);
106
107        Ok(Some(
108            VirtualNode::compute_row(input, &self.all_indices, self.vnode_count()?)
109                .to_scalar()
110                .into(),
111        ))
112    }
113}
114
115impl VnodeExpression {
116    fn vnode_count(&self) -> Result<usize> {
117        if let Some(vnode_count) = self.vnode_count {
118            Ok(vnode_count)
119        } else {
120            expr_context::vnode_count()
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use risingwave_common::array::{DataChunk, DataChunkTestExt};
    use risingwave_common::row::Row;
    use risingwave_expr::expr::build_from_pretty;
    use risingwave_expr::expr_context::VNODE_COUNT;

    /// Checks that both the chunk and the row evaluation paths of the internal `vnode`
    /// expression yield values within `0..vnode_count` when the count comes from the
    /// `VNODE_COUNT` scope.
    #[tokio::test]
    async fn test_vnode_expr_eval() {
        let vnode_count = 32;
        let in_range = |vnode: i16| (0..vnode_count as i16).contains(&vnode);

        let expr = build_from_pretty("(vnode:int2 $0:int4 $0:int8 $0:varchar)");
        let input = DataChunk::from_pretty(
            "i  I  T
             1  10 abc
             2  32 def
             3  88 ghi",
        );

        // Chunk-at-a-time evaluation.
        let output = VNODE_COUNT::scope(vnode_count, expr.eval(&input))
            .await
            .unwrap();
        for datum in output.iter() {
            let vnode = datum.unwrap().into_int16();
            assert!(in_range(vnode));
        }

        // Row-at-a-time evaluation.
        for row in input.rows() {
            let owned_row = row.to_owned_row();
            let datum = VNODE_COUNT::scope(vnode_count, expr.eval_row(&owned_row))
                .await
                .unwrap();
            let vnode = datum.unwrap().into_int16();
            assert!(in_range(vnode));
        }
    }
}