risingwave_expr/window_function/
session.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16use std::ops::Deref;
17use std::sync::Arc;
18
19use anyhow::Context;
20use educe::Educe;
21use futures::FutureExt;
22use risingwave_common::bail;
23use risingwave_common::row::OwnedRow;
24use risingwave_common::types::{
25    DataType, Datum, IsNegative, ScalarImpl, ScalarRefImpl, ToOwnedDatum, ToText,
26};
27use risingwave_common::util::sort_util::OrderType;
28use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt};
29use risingwave_pb::expr::window_frame::PbSessionFrameBounds;
30
31use super::FrameBoundsImpl;
32use crate::Result;
33use crate::expr::{
34    BoxedExpression, Expression, ExpressionBoxExt, InputRefExpression, LiteralExpression,
35    build_func,
36};
37
38/// To implement Session Window in a similar way to Range Frame, we define a similar frame bounds
39/// structure here. It's very like [`RangeFrameBounds`](super::RangeFrameBounds), but with a gap
40/// instead of start & end offset.
41#[derive(Debug, Clone, Eq, PartialEq, Hash)]
42pub struct SessionFrameBounds {
43    pub order_data_type: DataType,
44    pub order_type: OrderType,
45    pub gap_data_type: DataType,
46    pub gap: SessionFrameGap,
47}
48
49impl SessionFrameBounds {
50    pub(super) fn from_protobuf(bounds: &PbSessionFrameBounds) -> Result<Self> {
51        let order_data_type = DataType::from(bounds.get_order_data_type()?);
52        let order_type = OrderType::from_protobuf(bounds.get_order_type()?);
53        let gap_data_type = DataType::from(bounds.get_gap_data_type()?);
54        let gap_value = Datum::from_protobuf(bounds.get_gap()?, &gap_data_type)
55            .context("gap `Datum` is not decodable")?
56            .context("gap of session frame must be non-NULL")?;
57        let mut gap = SessionFrameGap::new(gap_value);
58        gap.prepare(&order_data_type, &gap_data_type)?;
59        Ok(Self {
60            order_data_type,
61            order_type,
62            gap_data_type,
63            gap,
64        })
65    }
66
67    pub(super) fn to_protobuf(&self) -> PbSessionFrameBounds {
68        PbSessionFrameBounds {
69            gap: Some(Some(self.gap.as_scalar_ref_impl()).to_protobuf()),
70            order_data_type: Some(self.order_data_type.to_protobuf()),
71            order_type: Some(self.order_type.to_protobuf()),
72            gap_data_type: Some(self.gap_data_type.to_protobuf()),
73        }
74    }
75}
76
77impl Display for SessionFrameBounds {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        write!(
80            f,
81            "SESSION WITH GAP {}",
82            self.gap.as_scalar_ref_impl().to_text()
83        )
84    }
85}
86
87impl FrameBoundsImpl for SessionFrameBounds {
88    fn validate(&self) -> Result<()> {
89        // TODO(rc): maybe can merge with `RangeFrameBounds::validate`
90
91        fn validate_non_negative(val: impl IsNegative + Display) -> Result<()> {
92            if val.is_negative() {
93                bail!("session gap should be non-negative, but {} is given", val);
94            }
95            Ok(())
96        }
97
98        match self.gap.as_scalar_ref_impl() {
99            ScalarRefImpl::Int16(val) => validate_non_negative(val)?,
100            ScalarRefImpl::Int32(val) => validate_non_negative(val)?,
101            ScalarRefImpl::Int64(val) => validate_non_negative(val)?,
102            ScalarRefImpl::Float32(val) => validate_non_negative(val)?,
103            ScalarRefImpl::Float64(val) => validate_non_negative(val)?,
104            ScalarRefImpl::Decimal(val) => validate_non_negative(val)?,
105            ScalarRefImpl::Interval(val) => {
106                if !val.is_never_negative() {
107                    bail!(
108                        "for session gap of type `interval`, each field should be non-negative, but {} is given",
109                        val
110                    );
111                }
112                if matches!(self.order_data_type, DataType::Timestamptz) {
113                    // for `timestamptz`, we only support gap without `month` and `day` fields
114                    if val.months() != 0 || val.days() != 0 {
115                        bail!(
116                            "for session order column of type `timestamptz`, gap should not have non-zero `month` and `day`",
117                        );
118                    }
119                }
120            }
121            _ => unreachable!(
122                "other order column data types are not supported and should be banned in frontend"
123            ),
124        }
125        Ok(())
126    }
127}
128
129impl SessionFrameBounds {
130    pub fn minimal_next_start_of(&self, end_order_value: impl ToOwnedDatum) -> Datum {
131        self.gap.for_calc().minimal_next_start_of(end_order_value)
132    }
133}
134
135/// The wrapper type for [`ScalarImpl`] session gap, containing an expression to help adding the gap
136/// to a given value.
137#[derive(Debug, Clone, Educe)]
138#[educe(PartialEq, Eq, Hash)]
139pub struct SessionFrameGap {
140    /// The original gap value.
141    gap: ScalarImpl,
142    /// Built expression for `$0 + gap`.
143    #[educe(PartialEq(ignore), Hash(ignore))]
144    add_expr: Option<Arc<BoxedExpression>>,
145}
146
147impl Deref for SessionFrameGap {
148    type Target = ScalarImpl;
149
150    fn deref(&self) -> &Self::Target {
151        &self.gap
152    }
153}
154
155impl SessionFrameGap {
156    pub fn new(gap: ScalarImpl) -> Self {
157        Self {
158            gap,
159            add_expr: None,
160        }
161    }
162
163    fn prepare(&mut self, order_data_type: &DataType, gap_data_type: &DataType) -> Result<()> {
164        use risingwave_pb::expr::expr_node::PbType as PbExprType;
165
166        let input_expr = InputRefExpression::new(order_data_type.clone(), 0);
167        let gap_expr = LiteralExpression::new(gap_data_type.clone(), Some(self.gap.clone()));
168        self.add_expr = Some(Arc::new(build_func(
169            PbExprType::Add,
170            order_data_type.clone(),
171            vec![input_expr.clone().boxed(), gap_expr.clone().boxed()],
172        )?));
173        Ok(())
174    }
175
176    pub fn new_for_test(
177        gap: ScalarImpl,
178        order_data_type: &DataType,
179        gap_data_type: &DataType,
180    ) -> Self {
181        let mut gap = Self::new(gap);
182        gap.prepare(order_data_type, gap_data_type).unwrap();
183        gap
184    }
185
186    fn for_calc(&self) -> SessionFrameGapRef<'_> {
187        SessionFrameGapRef {
188            add_expr: self.add_expr.as_ref().unwrap().as_ref(),
189        }
190    }
191}
192
193#[derive(Debug, Educe)]
194#[educe(Clone, Copy)]
195struct SessionFrameGapRef<'a> {
196    add_expr: &'a dyn Expression,
197}
198
199impl SessionFrameGapRef<'_> {
200    fn minimal_next_start_of(&self, end_order_value: impl ToOwnedDatum) -> Datum {
201        let row = OwnedRow::new(vec![end_order_value.to_owned_datum()]);
202        self.add_expr
203            .eval_row(&row)
204            .now_or_never()
205            .expect("frame bound calculation should finish immediately")
206            .expect("just simple calculation, should succeed") // TODO(rc): handle overflow
207    }
208}