1use chrono_tz::Tz;
16use num_traits::One;
17use risingwave_common::types::{CheckedAdd, Decimal, Interval, IsNegative, Timestamptz};
18use risingwave_expr::expr_context::TIME_ZONE;
19use risingwave_expr::{ExprError, Result, capture_context, function};
20
21#[function("generate_series(int4, int4) -> setof int4")]
22#[function("generate_series(int8, int8) -> setof int8")]
23fn generate_series<T>(start: T, stop: T) -> Result<impl Iterator<Item = T>>
24where
25 T: CheckedAdd<Output = T> + PartialOrd + Copy + One + IsNegative,
26{
27 range_generic::<_, _, _, true>(start, stop, T::one(), ())
28}
29
30#[function("generate_series(decimal, decimal) -> setof decimal")]
31fn generate_series_decimal(start: Decimal, stop: Decimal) -> Result<impl Iterator<Item = Decimal>>
32where
33{
34 validate_range_parameters(start, stop, Decimal::one())?;
35 range_generic::<Decimal, Decimal, _, true>(start, stop, Decimal::one(), ())
36}
37
38#[function("generate_series(int4, int4, int4) -> setof int4")]
39#[function("generate_series(int8, int8, int8) -> setof int8")]
40#[function("generate_series(timestamp, timestamp, interval) -> setof timestamp")]
41fn generate_series_step<T, S>(start: T, stop: T, step: S) -> Result<impl Iterator<Item = T>>
42where
43 T: CheckedAdd<S, Output = T> + PartialOrd + Copy,
44 S: IsNegative + Copy,
45{
46 range_generic::<_, _, _, true>(start, stop, step, ())
47}
48
49#[function("generate_series(decimal, decimal, decimal) -> setof decimal")]
50fn generate_series_step_decimal(
51 start: Decimal,
52 stop: Decimal,
53 step: Decimal,
54) -> Result<impl Iterator<Item = Decimal>> {
55 validate_range_parameters(start, stop, step)?;
56 range_generic::<_, _, _, true>(start, stop, step, ())
57}
58
59#[function("generate_series(timestamptz, timestamptz, interval) -> setof timestamptz")]
60fn generate_series_timestamptz_session(
61 start: Timestamptz,
62 stop: Timestamptz,
63 step: Interval,
64) -> Result<impl Iterator<Item = Timestamptz>> {
65 generate_series_timestamptz_impl_captured(start, stop, step)
66}
67
68#[function("generate_series(timestamptz, timestamptz, interval, varchar) -> setof timestamptz")]
69fn generate_series_timestamptz_at_zone(
70 start: Timestamptz,
71 stop: Timestamptz,
72 step: Interval,
73 time_zone: &str,
74) -> Result<impl Iterator<Item = Timestamptz>> {
75 generate_series_timestamptz_impl(time_zone, start, stop, step)
76}
77
78#[capture_context(TIME_ZONE)]
79fn generate_series_timestamptz_impl(
80 time_zone: &str,
81 start: Timestamptz,
82 stop: Timestamptz,
83 step: Interval,
84) -> Result<impl Iterator<Item = Timestamptz> + use<>> {
85 let time_zone =
86 Timestamptz::lookup_time_zone(time_zone).map_err(crate::scalar::time_zone_err)?;
87 range_generic::<_, _, _, true>(start, stop, step, time_zone)
88}
89
90#[function("range(int4, int4) -> setof int4")]
91#[function("range(int8, int8) -> setof int8")]
92fn range<T>(start: T, stop: T) -> Result<impl Iterator<Item = T>>
93where
94 T: CheckedAdd<Output = T> + PartialOrd + Copy + One + IsNegative,
95{
96 range_generic::<_, _, _, false>(start, stop, T::one(), ())
97}
98
99#[function("range(decimal, decimal) -> setof decimal")]
100fn range_decimal(start: Decimal, stop: Decimal) -> Result<impl Iterator<Item = Decimal>>
101where
102{
103 validate_range_parameters(start, stop, Decimal::one())?;
104 range_generic::<Decimal, Decimal, _, false>(start, stop, Decimal::one(), ())
105}
106
107#[function("range(int4, int4, int4) -> setof int4")]
108#[function("range(int8, int8, int8) -> setof int8")]
109#[function("range(timestamp, timestamp, interval) -> setof timestamp")]
110fn range_step<T, S>(start: T, stop: T, step: S) -> Result<impl Iterator<Item = T>>
111where
112 T: CheckedAdd<S, Output = T> + PartialOrd + Copy,
113 S: IsNegative + Copy,
114{
115 range_generic::<_, _, _, false>(start, stop, step, ())
116}
117
118#[function("range(decimal, decimal, decimal) -> setof decimal")]
119fn range_step_decimal(
120 start: Decimal,
121 stop: Decimal,
122 step: Decimal,
123) -> Result<impl Iterator<Item = Decimal>> {
124 validate_range_parameters(start, stop, step)?;
125 range_generic::<_, _, _, false>(start, stop, step, ())
126}
127
128pub trait CheckedAddWithExtra<Rhs = Self, Extra = ()> {
129 type Output;
130 fn checked_add_with_extra(self, rhs: Rhs, extra: Extra) -> Option<Self::Output>;
131}
132
133impl<L, R> CheckedAddWithExtra<R, ()> for L
134where
135 L: CheckedAdd<R>,
136{
137 type Output = L::Output;
138
139 fn checked_add_with_extra(self, rhs: R, _: ()) -> Option<Self::Output> {
140 self.checked_add(rhs)
141 }
142}
143
144impl CheckedAddWithExtra<Interval, Tz> for Timestamptz {
145 type Output = Self;
146
147 fn checked_add_with_extra(self, rhs: Interval, extra: Tz) -> Option<Self::Output> {
148 crate::scalar::timestamptz_interval_add_internal(self, rhs, extra).ok()
149 }
150}
151
152#[inline]
153fn range_generic<T, S, E, const INCLUSIVE: bool>(
154 start: T,
155 stop: T,
156 step: S,
157 extra: E,
158) -> Result<impl Iterator<Item = T>>
159where
160 T: CheckedAddWithExtra<S, E, Output = T> + PartialOrd + Copy,
161 S: IsNegative + Copy,
162 E: Copy,
163{
164 if step.is_zero() {
165 return Err(ExprError::InvalidParam {
166 name: "step",
167 reason: "step size cannot equal zero".into(),
168 });
169 }
170 let mut cur = start;
171 let neg = step.is_negative();
172 let next = move || {
173 match (INCLUSIVE, neg) {
174 (true, true) if cur < stop => return None,
175 (true, false) if cur > stop => return None,
176 (false, true) if cur <= stop => return None,
177 (false, false) if cur >= stop => return None,
178 _ => {}
179 };
180 let ret = cur;
181 cur = cur.checked_add_with_extra(step, extra)?;
182 Some(ret)
183 };
184 Ok(std::iter::from_fn(next))
185}
186
187#[inline]
189fn validate_range_parameters(start: Decimal, stop: Decimal, step: Decimal) -> Result<()> {
190 validate_decimal(start, "start")?;
191 validate_decimal(stop, "stop")?;
192 validate_decimal(step, "step")?;
193 Ok(())
194}
195
196#[inline]
197fn validate_decimal(decimal: Decimal, name: &'static str) -> Result<()> {
198 match decimal {
199 Decimal::Normalized(_) => Ok(()),
200 Decimal::PositiveInf | Decimal::NegativeInf => Err(ExprError::InvalidParam {
201 name,
202 reason: format!("{} value cannot be infinity", name).into(),
203 }),
204 Decimal::NaN => Err(ExprError::InvalidParam {
205 name,
206 reason: format!("{} value cannot be NaN", name).into(),
207 }),
208 }
209}
210
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use futures_util::StreamExt;
    use risingwave_common::array::DataChunk;
    use risingwave_common::types::test_utils::IntervalTestExt;
    use risingwave_common::types::{DataType, Decimal, Interval, ScalarImpl, Timestamp};
    use risingwave_expr::expr::{BoxedExpression, ExpressionBoxExt, LiteralExpression};
    use risingwave_expr::table_function::{build, check_error};
    use risingwave_pb::expr::table_function::PbType;

    /// Chunk size passed to `build`. The `CHUNK_SIZE * 2 + 3` cases presumably exercise
    /// output that spans several chunks.
    const CHUNK_SIZE: usize = 1024;

    #[tokio::test]
    async fn test_generate_series_i32() {
        generate_series_i32(2, 4, 1).await;
        generate_series_i32(4, 2, -1).await;
        generate_series_i32(0, 9, 2).await;
        generate_series_i32(0, (CHUNK_SIZE * 2 + 3) as i32, 1).await;
    }

    /// Builds `generate_series(start, stop, step)` over int4 and checks the row count.
    async fn generate_series_i32(start: i32, stop: i32, step: i32) {
        fn literal(v: i32) -> BoxedExpression {
            LiteralExpression::new(DataType::Int32, Some(v.into())).boxed()
        }
        let function = build(
            PbType::GenerateSeries,
            DataType::Int32,
            CHUNK_SIZE,
            vec![literal(start), literal(stop), literal(step)],
        )
        .unwrap();
        // Inclusive upper bound: one row per step plus the starting value.
        let expect_cnt = ((stop - start) / step + 1) as usize;

        let dummy_chunk = DataChunk::new_dummy(1);
        let mut actual_cnt = 0;
        let mut output = function.eval(&dummy_chunk).await;
        while let Some(res) = output.next().await {
            // Surface evaluation errors instead of silently ending the count early.
            let chunk = res.unwrap();
            actual_cnt += chunk.cardinality();
        }
        assert_eq!(actual_cnt, expect_cnt);
    }

    #[tokio::test]
    async fn test_generate_series_timestamp() {
        let start_time = Timestamp::from_str("2008-03-01 00:00:00").unwrap();
        let stop_time = Timestamp::from_str("2008-03-09 00:00:00").unwrap();
        let one_minute_step = Interval::from_minutes(1);
        let one_hour_step = Interval::from_minutes(60);
        let one_day_step = Interval::from_days(1);
        generate_series_timestamp(start_time, stop_time, one_minute_step, 60 * 24 * 8 + 1).await;
        generate_series_timestamp(start_time, stop_time, one_hour_step, 24 * 8 + 1).await;
        generate_series_timestamp(start_time, stop_time, one_day_step, 8 + 1).await;
        generate_series_timestamp(stop_time, start_time, -one_day_step, 8 + 1).await;
    }

    /// Builds `generate_series(timestamp, timestamp, interval)` and checks the row count.
    async fn generate_series_timestamp(
        start: Timestamp,
        stop: Timestamp,
        step: Interval,
        expect_cnt: usize,
    ) {
        fn literal(ty: DataType, v: ScalarImpl) -> BoxedExpression {
            LiteralExpression::new(ty, Some(v)).boxed()
        }
        let function = build(
            PbType::GenerateSeries,
            DataType::Timestamp,
            CHUNK_SIZE,
            vec![
                literal(DataType::Timestamp, start.into()),
                literal(DataType::Timestamp, stop.into()),
                literal(DataType::Interval, step.into()),
            ],
        )
        .unwrap();

        let dummy_chunk = DataChunk::new_dummy(1);
        let mut actual_cnt = 0;
        let mut output = function.eval(&dummy_chunk).await;
        while let Some(res) = output.next().await {
            // Surface evaluation errors instead of silently ending the count early.
            let chunk = res.unwrap();
            actual_cnt += chunk.cardinality();
        }
        assert_eq!(actual_cnt, expect_cnt);
    }

    #[tokio::test]
    async fn test_range_i32() {
        range_i32(2, 4, 1).await;
        range_i32(4, 2, -1).await;
        range_i32(0, 9, 2).await;
        range_i32(0, (CHUNK_SIZE * 2 + 3) as i32, 1).await;
    }

    /// Builds `range(start, stop, step)` over int4 and checks the row count.
    async fn range_i32(start: i32, stop: i32, step: i32) {
        fn literal(v: i32) -> BoxedExpression {
            LiteralExpression::new(DataType::Int32, Some(v.into())).boxed()
        }
        let function = build(
            PbType::Range,
            DataType::Int32,
            CHUNK_SIZE,
            vec![literal(start), literal(stop), literal(step)],
        )
        .unwrap();
        // Exclusive upper bound: shift `stop` one unit towards `start` before counting.
        let expect_cnt = ((stop - start - step.signum()) / step + 1) as usize;

        let dummy_chunk = DataChunk::new_dummy(1);
        let mut actual_cnt = 0;
        let mut output = function.eval(&dummy_chunk).await;
        while let Some(res) = output.next().await {
            // Surface evaluation errors instead of silently ending the count early.
            let chunk = res.unwrap();
            actual_cnt += chunk.cardinality();
        }
        assert_eq!(actual_cnt, expect_cnt);
    }

    #[tokio::test]
    async fn test_range_timestamp() {
        let start_time = Timestamp::from_str("2008-03-01 00:00:00").unwrap();
        let stop_time = Timestamp::from_str("2008-03-09 00:00:00").unwrap();
        let one_minute_step = Interval::from_minutes(1);
        let one_hour_step = Interval::from_minutes(60);
        let one_day_step = Interval::from_days(1);
        range_timestamp(start_time, stop_time, one_minute_step, 60 * 24 * 8).await;
        range_timestamp(start_time, stop_time, one_hour_step, 24 * 8).await;
        range_timestamp(start_time, stop_time, one_day_step, 8).await;
        range_timestamp(stop_time, start_time, -one_day_step, 8).await;
    }

    /// Builds `range(timestamp, timestamp, interval)` and checks the row count.
    async fn range_timestamp(start: Timestamp, stop: Timestamp, step: Interval, expect_cnt: usize) {
        fn literal(ty: DataType, v: ScalarImpl) -> BoxedExpression {
            LiteralExpression::new(ty, Some(v)).boxed()
        }
        let function = build(
            PbType::Range,
            DataType::Timestamp,
            CHUNK_SIZE,
            vec![
                literal(DataType::Timestamp, start.into()),
                literal(DataType::Timestamp, stop.into()),
                literal(DataType::Interval, step.into()),
            ],
        )
        .unwrap();

        let dummy_chunk = DataChunk::new_dummy(1);
        let mut actual_cnt = 0;
        let mut output = function.eval(&dummy_chunk).await;
        while let Some(res) = output.next().await {
            // Surface evaluation errors instead of silently ending the count early.
            let chunk = res.unwrap();
            actual_cnt += chunk.cardinality();
        }
        assert_eq!(actual_cnt, expect_cnt);
    }

    #[tokio::test]
    async fn test_generate_series_decimal() {
        generate_series_decimal("1", "5", "1", true).await;
        generate_series_decimal("inf", "5", "1", false).await;
        generate_series_decimal("inf", "-inf", "1", false).await;
        generate_series_decimal("1", "-inf", "1", false).await;
        generate_series_decimal("1", "5", "nan", false).await;
        generate_series_decimal("1", "5", "inf", false).await;
        generate_series_decimal("1", "-inf", "nan", false).await;
    }

    /// Builds decimal `generate_series` and checks whether each output chunk carries an error.
    async fn generate_series_decimal(start: &str, stop: &str, step: &str, expect_ok: bool) {
        fn decimal_literal(v: Decimal) -> BoxedExpression {
            LiteralExpression::new(DataType::Decimal, Some(v.into())).boxed()
        }
        let function = build(
            PbType::GenerateSeries,
            DataType::Decimal,
            CHUNK_SIZE,
            vec![
                decimal_literal(start.parse().unwrap()),
                decimal_literal(stop.parse().unwrap()),
                decimal_literal(step.parse().unwrap()),
            ],
        )
        .unwrap();

        let dummy_chunk = DataChunk::new_dummy(1);
        let mut output = function.eval(&dummy_chunk).await;
        // NOTE(review): if evaluation yields no chunks at all, nothing below is asserted;
        // consider also asserting that at least one chunk was observed.
        while let Some(res) = output.next().await {
            let chunk = res.unwrap();
            let error = check_error(&chunk);
            assert_eq!(
                error.is_ok(),
                expect_ok,
                "generate_series({start}, {stop}, {step})"
            );
        }
    }
}