risingwave_expr_impl/scalar/
vnode.rs1use std::sync::Arc;
16
17use anyhow::Context;
18use itertools::Itertools;
19use risingwave_common::array::{ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, I16ArrayBuilder};
20use risingwave_common::hash::VirtualNode;
21use risingwave_common::row::OwnedRow;
22use risingwave_common::types::{DataType, Datum};
23use risingwave_expr::expr::{BoxedExpression, Expression};
24use risingwave_expr::{Result, build_function, expr_context};
25
26#[derive(Debug)]
27struct VnodeExpression {
28 vnode_count: Option<usize>,
31
32 children: Vec<BoxedExpression>,
34
35 all_indices: Vec<usize>,
39}
40
41#[build_function("vnode(...) -> int2")]
42fn build(_: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
43 Ok(Box::new(VnodeExpression {
44 vnode_count: None,
45 all_indices: (0..children.len()).collect(),
46 children,
47 }))
48}
49
50#[build_function("vnode_user(...) -> int2")]
51fn build_user(_: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
52 let mut children = children.into_iter();
53
54 let vnode_count = children
55 .next()
56 .unwrap() .eval_const() .context("the first argument (vnode count) must be a constant")?
59 .context("the first argument (vnode count) must not be NULL")?
60 .into_int32(); if !(1i32..=VirtualNode::MAX_COUNT as i32).contains(&vnode_count) {
63 return Err(anyhow::anyhow!(
64 "the first argument (vnode count) must be in range 1..={}",
65 VirtualNode::MAX_COUNT
66 )
67 .into());
68 }
69
70 let children = children.collect_vec();
71
72 Ok(Box::new(VnodeExpression {
73 vnode_count: Some(vnode_count.try_into().unwrap()),
74 all_indices: (0..children.len()).collect(),
75 children,
76 }))
77}
78
79#[async_trait::async_trait]
80impl Expression for VnodeExpression {
81 fn return_type(&self) -> DataType {
82 DataType::Int16
83 }
84
85 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
86 let mut arrays = Vec::with_capacity(self.children.len());
87 for child in &self.children {
88 arrays.push(child.eval(input).await?);
89 }
90 let input = DataChunk::new(arrays, input.visibility().clone());
91
92 let vnodes = VirtualNode::compute_chunk(&input, &self.all_indices, self.vnode_count()?);
93 let mut builder = I16ArrayBuilder::new(input.capacity());
94 vnodes
95 .into_iter()
96 .for_each(|vnode| builder.append(Some(vnode.to_scalar())));
97 Ok(Arc::new(ArrayImpl::from(builder.finish())))
98 }
99
100 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
101 let mut datums = Vec::with_capacity(self.children.len());
102 for child in &self.children {
103 datums.push(child.eval_row(input).await?);
104 }
105 let input = OwnedRow::new(datums);
106
107 Ok(Some(
108 VirtualNode::compute_row(input, &self.all_indices, self.vnode_count()?)
109 .to_scalar()
110 .into(),
111 ))
112 }
113}
114
115impl VnodeExpression {
116 fn vnode_count(&self) -> Result<usize> {
117 if let Some(vnode_count) = self.vnode_count {
118 Ok(vnode_count)
119 } else {
120 expr_context::vnode_count()
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use risingwave_common::array::{DataChunk, DataChunkTestExt};
    use risingwave_common::row::Row;
    use risingwave_expr::expr::build_from_pretty;
    use risingwave_expr::expr_context::VNODE_COUNT;

    /// Checks that `vnode` over all three input columns yields in-range vnodes, and that the
    /// vectorized (`eval`) and row-wise (`eval_row`) paths agree on every row.
    #[tokio::test]
    async fn test_vnode_expr_eval() {
        let vnode_count = 32;
        // Reference each input column with its own type (i -> int4, I -> int8, T -> varchar).
        let expr = build_from_pretty("(vnode:int2 $0:int4 $1:int8 $2:varchar)");
        let input = DataChunk::from_pretty(
            "i I T
             1 10 abc
             2 32 def
             3 88 ghi",
        );

        let output = VNODE_COUNT::scope(vnode_count, expr.eval(&input))
            .await
            .unwrap();
        assert_eq!(output.len(), input.capacity());

        for (chunk_vnode, row) in output.iter().zip(input.rows()) {
            let chunk_vnode = chunk_vnode.unwrap().into_int16();
            assert!((0..vnode_count as i16).contains(&chunk_vnode));

            let row_vnode = VNODE_COUNT::scope(vnode_count, expr.eval_row(&row.to_owned_row()))
                .await
                .unwrap()
                .unwrap()
                .into_int16();
            assert_eq!(chunk_vnode, row_vnode);
        }
    }
}