risingwave_expr_impl/table_function/
generate_series.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use chrono_tz::Tz;
16use num_traits::One;
17use risingwave_common::types::{CheckedAdd, Decimal, Interval, IsNegative, Timestamptz};
18use risingwave_expr::expr_context::TIME_ZONE;
19use risingwave_expr::{ExprError, Result, capture_context, function};
20
21#[function("generate_series(int4, int4) -> setof int4")]
22#[function("generate_series(int8, int8) -> setof int8")]
23fn generate_series<T>(start: T, stop: T) -> Result<impl Iterator<Item = T>>
24where
25    T: CheckedAdd<Output = T> + PartialOrd + Copy + One + IsNegative,
26{
27    range_generic::<_, _, _, true>(start, stop, T::one(), ())
28}
29
30#[function("generate_series(decimal, decimal) -> setof decimal")]
31fn generate_series_decimal(start: Decimal, stop: Decimal) -> Result<impl Iterator<Item = Decimal>>
32where
33{
34    validate_range_parameters(start, stop, Decimal::one())?;
35    range_generic::<Decimal, Decimal, _, true>(start, stop, Decimal::one(), ())
36}
37
38#[function("generate_series(int4, int4, int4) -> setof int4")]
39#[function("generate_series(int8, int8, int8) -> setof int8")]
40#[function("generate_series(timestamp, timestamp, interval) -> setof timestamp")]
41fn generate_series_step<T, S>(start: T, stop: T, step: S) -> Result<impl Iterator<Item = T>>
42where
43    T: CheckedAdd<S, Output = T> + PartialOrd + Copy,
44    S: IsNegative + Copy,
45{
46    range_generic::<_, _, _, true>(start, stop, step, ())
47}
48
49#[function("generate_series(decimal, decimal, decimal) -> setof decimal")]
50fn generate_series_step_decimal(
51    start: Decimal,
52    stop: Decimal,
53    step: Decimal,
54) -> Result<impl Iterator<Item = Decimal>> {
55    validate_range_parameters(start, stop, step)?;
56    range_generic::<_, _, _, true>(start, stop, step, ())
57}
58
59#[function("generate_series(timestamptz, timestamptz, interval) -> setof timestamptz")]
60fn generate_series_timestamptz_session(
61    start: Timestamptz,
62    stop: Timestamptz,
63    step: Interval,
64) -> Result<impl Iterator<Item = Timestamptz>> {
65    generate_series_timestamptz_impl_captured(start, stop, step)
66}
67
68#[function("generate_series(timestamptz, timestamptz, interval, varchar) -> setof timestamptz")]
69fn generate_series_timestamptz_at_zone(
70    start: Timestamptz,
71    stop: Timestamptz,
72    step: Interval,
73    time_zone: &str,
74) -> Result<impl Iterator<Item = Timestamptz>> {
75    generate_series_timestamptz_impl(time_zone, start, stop, step)
76}
77
78#[capture_context(TIME_ZONE)]
79fn generate_series_timestamptz_impl(
80    time_zone: &str,
81    start: Timestamptz,
82    stop: Timestamptz,
83    step: Interval,
84) -> Result<impl Iterator<Item = Timestamptz> + use<>> {
85    let time_zone =
86        Timestamptz::lookup_time_zone(time_zone).map_err(crate::scalar::time_zone_err)?;
87    range_generic::<_, _, _, true>(start, stop, step, time_zone)
88}
89
90#[function("range(int4, int4) -> setof int4")]
91#[function("range(int8, int8) -> setof int8")]
92fn range<T>(start: T, stop: T) -> Result<impl Iterator<Item = T>>
93where
94    T: CheckedAdd<Output = T> + PartialOrd + Copy + One + IsNegative,
95{
96    range_generic::<_, _, _, false>(start, stop, T::one(), ())
97}
98
99#[function("range(decimal, decimal) -> setof decimal")]
100fn range_decimal(start: Decimal, stop: Decimal) -> Result<impl Iterator<Item = Decimal>>
101where
102{
103    validate_range_parameters(start, stop, Decimal::one())?;
104    range_generic::<Decimal, Decimal, _, false>(start, stop, Decimal::one(), ())
105}
106
107#[function("range(int4, int4, int4) -> setof int4")]
108#[function("range(int8, int8, int8) -> setof int8")]
109#[function("range(timestamp, timestamp, interval) -> setof timestamp")]
110fn range_step<T, S>(start: T, stop: T, step: S) -> Result<impl Iterator<Item = T>>
111where
112    T: CheckedAdd<S, Output = T> + PartialOrd + Copy,
113    S: IsNegative + Copy,
114{
115    range_generic::<_, _, _, false>(start, stop, step, ())
116}
117
118#[function("range(decimal, decimal, decimal) -> setof decimal")]
119fn range_step_decimal(
120    start: Decimal,
121    stop: Decimal,
122    step: Decimal,
123) -> Result<impl Iterator<Item = Decimal>> {
124    validate_range_parameters(start, stop, step)?;
125    range_generic::<_, _, _, false>(start, stop, step, ())
126}
127
/// Checked addition that receives one additional, caller-supplied argument.
///
/// `Rhs` is the type being added and defaults to `Self`; `Extra` is the type of the
/// additional argument and defaults to `()`. In this file the `Timestamptz`
/// implementation uses a `Tz` as `Extra`, while plain `CheckedAdd` types use `()`.
pub trait CheckedAddWithExtra<Rhs = Self, Extra = ()> {
    /// Type produced by a successful addition.
    type Output;
    /// Adds `rhs` to `self`, taking `extra` into account.
    ///
    /// Returns `None` when the implementation cannot produce a result.
    fn checked_add_with_extra(self, rhs: Rhs, extra: Extra) -> Option<Self::Output>;
}
132
133impl<L, R> CheckedAddWithExtra<R, ()> for L
134where
135    L: CheckedAdd<R>,
136{
137    type Output = L::Output;
138
139    fn checked_add_with_extra(self, rhs: R, _: ()) -> Option<Self::Output> {
140        self.checked_add(rhs)
141    }
142}
143
144impl CheckedAddWithExtra<Interval, Tz> for Timestamptz {
145    type Output = Self;
146
147    fn checked_add_with_extra(self, rhs: Interval, extra: Tz) -> Option<Self::Output> {
148        crate::scalar::timestamptz_interval_add_internal(self, rhs, extra).ok()
149    }
150}
151
152#[inline]
153fn range_generic<T, S, E, const INCLUSIVE: bool>(
154    start: T,
155    stop: T,
156    step: S,
157    extra: E,
158) -> Result<impl Iterator<Item = T>>
159where
160    T: CheckedAddWithExtra<S, E, Output = T> + PartialOrd + Copy,
161    S: IsNegative + Copy,
162    E: Copy,
163{
164    if step.is_zero() {
165        return Err(ExprError::InvalidParam {
166            name: "step",
167            reason: "step size cannot equal zero".into(),
168        });
169    }
170    let mut cur = start;
171    let neg = step.is_negative();
172    let next = move || {
173        match (INCLUSIVE, neg) {
174            (true, true) if cur < stop => return None,
175            (true, false) if cur > stop => return None,
176            (false, true) if cur <= stop => return None,
177            (false, false) if cur >= stop => return None,
178            _ => {}
179        };
180        let ret = cur;
181        cur = cur.checked_add_with_extra(step, extra)?;
182        Some(ret)
183    };
184    Ok(std::iter::from_fn(next))
185}
186
187/// Validate decimals can not be `NaN` or `infinity`.
188#[inline]
189fn validate_range_parameters(start: Decimal, stop: Decimal, step: Decimal) -> Result<()> {
190    validate_decimal(start, "start")?;
191    validate_decimal(stop, "stop")?;
192    validate_decimal(step, "step")?;
193    Ok(())
194}
195
196#[inline]
197fn validate_decimal(decimal: Decimal, name: &'static str) -> Result<()> {
198    match decimal {
199        Decimal::Normalized(_) => Ok(()),
200        Decimal::PositiveInf | Decimal::NegativeInf => Err(ExprError::InvalidParam {
201            name,
202            reason: format!("{} value cannot be infinity", name).into(),
203        }),
204        Decimal::NaN => Err(ExprError::InvalidParam {
205            name,
206            reason: format!("{} value cannot be NaN", name).into(),
207        }),
208    }
209}
210
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use futures_util::StreamExt;
    use risingwave_common::array::DataChunk;
    use risingwave_common::types::test_utils::IntervalTestExt;
    use risingwave_common::types::{DataType, Decimal, Interval, ScalarImpl, Timestamp};
    use risingwave_expr::expr::{BoxedExpression, ExpressionBoxExt, LiteralExpression};
    use risingwave_expr::table_function::{build, check_error};
    use risingwave_pb::expr::table_function::PbType;

    /// Chunk size handed to `build` for every table function under test.
    const CHUNK_SIZE: usize = 1024;

    /// Boxes a non-null literal expression of type `ty` holding `value`.
    fn literal(ty: DataType, value: impl Into<ScalarImpl>) -> BoxedExpression {
        LiteralExpression::new(ty, Some(value.into())).boxed()
    }

    /// Builds the table function `kind` returning `return_type` over `args`, evaluates
    /// it against a one-row dummy chunk and sums the cardinality of the chunks it
    /// produces.
    async fn count_rows(kind: PbType, return_type: DataType, args: Vec<BoxedExpression>) -> usize {
        let function = build(kind, return_type, CHUNK_SIZE, args).unwrap();
        let input = DataChunk::new_dummy(1);
        let mut chunks = function.eval(&input).await;
        let mut rows = 0;
        // Counting stops at the first error as well as at the end of the stream.
        while let Some(Ok(chunk)) = chunks.next().await {
            rows += chunk.cardinality();
        }
        rows
    }

    /// Literal argument list for the three-argument `int4` functions.
    fn i32_args(start: i32, stop: i32, step: i32) -> Vec<BoxedExpression> {
        [start, stop, step]
            .into_iter()
            .map(|v| literal(DataType::Int32, v))
            .collect()
    }

    /// Literal argument list for the `(timestamp, timestamp, interval)` functions.
    fn timestamp_args(start: Timestamp, stop: Timestamp, step: Interval) -> Vec<BoxedExpression> {
        vec![
            literal(DataType::Timestamp, start),
            literal(DataType::Timestamp, stop),
            literal(DataType::Interval, step),
        ]
    }

    #[tokio::test]
    async fn test_generate_series_i32() {
        let cases = [
            (2, 4, 1),
            (4, 2, -1),
            (0, 9, 2),
            (0, (CHUNK_SIZE * 2 + 3) as i32, 1),
        ];
        for (start, stop, step) in cases {
            generate_series_i32(start, stop, step).await;
        }
    }

    /// Checks that `generate_series(start, stop, step)` emits the expected row count
    /// for an inclusive series.
    async fn generate_series_i32(start: i32, stop: i32, step: i32) {
        let args = i32_args(start, stop, step);
        let actual_cnt = count_rows(PbType::GenerateSeries, DataType::Int32, args).await;
        let expect_cnt = ((stop - start) / step + 1) as usize;
        assert_eq!(actual_cnt, expect_cnt);
    }

    #[tokio::test]
    async fn test_generate_series_timestamp() {
        let start_time = Timestamp::from_str("2008-03-01 00:00:00").unwrap();
        let stop_time = Timestamp::from_str("2008-03-09 00:00:00").unwrap();
        let one_day_step = Interval::from_days(1);
        let cases = [
            (start_time, stop_time, Interval::from_minutes(1), 60 * 24 * 8 + 1),
            (start_time, stop_time, Interval::from_minutes(60), 24 * 8 + 1),
            (start_time, stop_time, one_day_step, 8 + 1),
            (stop_time, start_time, -one_day_step, 8 + 1),
        ];
        for (start, stop, step, expect_cnt) in cases {
            generate_series_timestamp(start, stop, step, expect_cnt).await;
        }
    }

    /// Checks that the timestamp `generate_series` emits exactly `expect_cnt` rows.
    async fn generate_series_timestamp(
        start: Timestamp,
        stop: Timestamp,
        step: Interval,
        expect_cnt: usize,
    ) {
        let args = timestamp_args(start, stop, step);
        let actual_cnt = count_rows(PbType::GenerateSeries, DataType::Timestamp, args).await;
        assert_eq!(actual_cnt, expect_cnt);
    }

    #[tokio::test]
    async fn test_range_i32() {
        let cases = [
            (2, 4, 1),
            (4, 2, -1),
            (0, 9, 2),
            (0, (CHUNK_SIZE * 2 + 3) as i32, 1),
        ];
        for (start, stop, step) in cases {
            range_i32(start, stop, step).await;
        }
    }

    /// Checks that `range(start, stop, step)` emits the expected row count for a
    /// series that excludes `stop`.
    async fn range_i32(start: i32, stop: i32, step: i32) {
        let args = i32_args(start, stop, step);
        let actual_cnt = count_rows(PbType::Range, DataType::Int32, args).await;
        let expect_cnt = ((stop - start - step.signum()) / step + 1) as usize;
        assert_eq!(actual_cnt, expect_cnt);
    }

    #[tokio::test]
    async fn test_range_timestamp() {
        let start_time = Timestamp::from_str("2008-03-01 00:00:00").unwrap();
        let stop_time = Timestamp::from_str("2008-03-09 00:00:00").unwrap();
        let one_day_step = Interval::from_days(1);
        let cases = [
            (start_time, stop_time, Interval::from_minutes(1), 60 * 24 * 8),
            (start_time, stop_time, Interval::from_minutes(60), 24 * 8),
            (start_time, stop_time, one_day_step, 8),
            (stop_time, start_time, -one_day_step, 8),
        ];
        for (start, stop, step, expect_cnt) in cases {
            range_timestamp(start, stop, step, expect_cnt).await;
        }
    }

    /// Checks that the timestamp `range` emits exactly `expect_cnt` rows.
    async fn range_timestamp(start: Timestamp, stop: Timestamp, step: Interval, expect_cnt: usize) {
        let args = timestamp_args(start, stop, step);
        let actual_cnt = count_rows(PbType::Range, DataType::Timestamp, args).await;
        assert_eq!(actual_cnt, expect_cnt);
    }

    #[tokio::test]
    async fn test_generate_series_decimal() {
        let cases = [
            ("1", "5", "1", true),
            ("inf", "5", "1", false),
            ("inf", "-inf", "1", false),
            ("1", "-inf", "1", false),
            ("1", "5", "nan", false),
            ("1", "5", "inf", false),
            ("1", "-inf", "nan", false),
        ];
        for (start, stop, step, expect_ok) in cases {
            generate_series_decimal(start, stop, step, expect_ok).await;
        }
    }

    /// Parses the three arguments as decimals, runs `generate_series` over them and
    /// asserts that every produced chunk is error-free exactly when `expect_ok`.
    async fn generate_series_decimal(start: &str, stop: &str, step: &str, expect_ok: bool) {
        let args: Vec<BoxedExpression> = [start, stop, step]
            .into_iter()
            .map(|text| literal(DataType::Decimal, text.parse::<Decimal>().unwrap()))
            .collect();
        let function = build(
            PbType::GenerateSeries,
            DataType::Decimal,
            CHUNK_SIZE,
            args,
        )
        .unwrap();

        let dummy_chunk = DataChunk::new_dummy(1);
        let mut output = function.eval(&dummy_chunk).await;
        while let Some(res) = output.next().await {
            let chunk = res.unwrap();
            let outcome = check_error(&chunk);
            assert_eq!(
                outcome.is_ok(),
                expect_ok,
                "generate_series({start}, {stop}, {step})"
            );
        }
    }
}