risingwave_expr/expr/
test_utils.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Helper functions to construct prost [`ExprNode`] for test.
16
17use std::num::NonZeroUsize;
18
19use num_traits::CheckedSub;
20use risingwave_common::types::{DataType, Interval, ScalarImpl};
21use risingwave_common::util::value_encoding::DatumToProtoExt;
22use risingwave_pb::data::PbDataType;
23use risingwave_pb::data::data_type::TypeName;
24use risingwave_pb::expr::expr_node::Type::Field;
25use risingwave_pb::expr::expr_node::{self, RexNode, Type};
26use risingwave_pb::expr::{ExprNode, FunctionCall};
27
28use super::{BoxedExpression, Result, build_from_prost};
29use crate::ExprError;
30
31pub fn make_func_call(kind: Type, ret: TypeName, children: Vec<ExprNode>) -> ExprNode {
32    ExprNode {
33        function_type: kind as i32,
34        return_type: Some(PbDataType {
35            type_name: ret as i32,
36            ..Default::default()
37        }),
38        rex_node: Some(RexNode::FuncCall(FunctionCall { children })),
39    }
40}
41
42pub fn make_input_ref(idx: usize, ret: TypeName) -> ExprNode {
43    ExprNode {
44        function_type: Type::Unspecified as i32,
45        return_type: Some(PbDataType {
46            type_name: ret as i32,
47            ..Default::default()
48        }),
49        rex_node: Some(RexNode::InputRef(idx as _)),
50    }
51}
52
53pub fn make_i32_literal(data: i32) -> ExprNode {
54    ExprNode {
55        function_type: Type::Unspecified as i32,
56        return_type: Some(PbDataType {
57            type_name: TypeName::Int32 as i32,
58            ..Default::default()
59        }),
60        rex_node: Some(RexNode::Constant(
61            Some(ScalarImpl::Int32(data)).to_protobuf(),
62        )),
63    }
64}
65
66fn make_interval_literal(data: Interval) -> ExprNode {
67    ExprNode {
68        function_type: Type::Unspecified as i32,
69        return_type: Some(PbDataType {
70            type_name: TypeName::Interval as i32,
71            ..Default::default()
72        }),
73        rex_node: Some(RexNode::Constant(
74            Some(ScalarImpl::Interval(data)).to_protobuf(),
75        )),
76    }
77}
78
79pub fn make_field_function(children: Vec<ExprNode>, ret: TypeName) -> ExprNode {
80    ExprNode {
81        function_type: Field as i32,
82        return_type: Some(PbDataType {
83            type_name: ret as i32,
84            ..Default::default()
85        }),
86        rex_node: Some(RexNode::FuncCall(FunctionCall { children })),
87    }
88}
89
90pub fn make_hop_window_expression(
91    time_col_data_type: DataType,
92    time_col_idx: usize,
93    window_size: Interval,
94    window_slide: Interval,
95    window_offset: Interval,
96) -> Result<(Vec<BoxedExpression>, Vec<BoxedExpression>)> {
97    let units = window_size
98        .exact_div(&window_slide)
99        .and_then(|x| NonZeroUsize::new(usize::try_from(x).ok()?))
100        .ok_or_else(|| ExprError::InvalidParam {
101            name: "window",
102            reason: format!(
103                "window_size {} cannot be divided by window_slide {}",
104                window_size, window_slide
105            )
106            .into(),
107        })?
108        .get();
109
110    let output_type = DataType::window_of(&time_col_data_type)
111        .unwrap()
112        .to_protobuf()
113        .type_name();
114
115    let time_col_ref = make_input_ref(time_col_idx, time_col_data_type.to_protobuf().type_name());
116
117    // The first window_start of hop window should be:
118    // tumble_start(`time_col` - (`window_size` - `window_slide`), `window_slide`, `window_offset`).
119    // Let's pre calculate (`window_size` - `window_slide`).
120    let window_size_sub_slide = window_size
121        .checked_sub(&window_slide)
122        .ok_or_else(|| ExprError::InvalidParam {
123            name: "window",
124            reason: format!(
125                "window_size {} cannot be subtracted by window_slide {}",
126                window_size, window_slide
127            )
128            .into(),
129        })
130        .unwrap();
131
132    let hop_window_start = make_func_call(
133        expr_node::Type::TumbleStart,
134        output_type,
135        vec![
136            make_func_call(
137                expr_node::Type::Subtract,
138                output_type,
139                vec![time_col_ref, make_interval_literal(window_size_sub_slide)],
140            ),
141            make_interval_literal(window_slide),
142            make_interval_literal(window_offset),
143        ],
144    );
145
146    let mut window_start_exprs = Vec::with_capacity(units);
147    let mut window_end_exprs = Vec::with_capacity(units);
148    for i in 0..units {
149        let window_start_offset =
150            window_slide
151                .checked_mul_int(i)
152                .ok_or_else(|| ExprError::InvalidParam {
153                    name: "window",
154                    reason: format!(
155                        "window_slide {} cannot be multiplied by {}",
156                        window_slide, i
157                    )
158                    .into(),
159                })?;
160        let window_end_offset =
161            window_slide
162                .checked_mul_int(i + units)
163                .ok_or_else(|| ExprError::InvalidParam {
164                    name: "window",
165                    reason: format!(
166                        "window_slide {} cannot be multiplied by {}",
167                        window_slide, i
168                    )
169                    .into(),
170                })?;
171        let window_start_expr = make_func_call(
172            expr_node::Type::Add,
173            output_type,
174            vec![
175                hop_window_start.clone(),
176                make_interval_literal(window_start_offset),
177            ],
178        );
179        window_start_exprs.push(build_from_prost(&window_start_expr).unwrap());
180        let window_end_expr = make_func_call(
181            expr_node::Type::Add,
182            output_type,
183            vec![
184                hop_window_start.clone(),
185                make_interval_literal(window_end_offset),
186            ],
187        );
188        window_end_exprs.push(build_from_prost(&window_end_expr).unwrap());
189    }
190    Ok((window_start_exprs, window_end_exprs))
191}