risingwave_stream/from_proto/
temporal_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use risingwave_common::catalog::ColumnId;
18use risingwave_common::hash::{HashKey, HashKeyDispatcher};
19use risingwave_common::types::DataType;
20use risingwave_expr::expr::{NonStrictExpression, build_non_strict_from_prost};
21use risingwave_pb::plan_common::{JoinType as JoinTypeProto, StorageTableDesc};
22use risingwave_storage::table::batch_table::BatchTable;
23
24use super::*;
25use crate::common::table::state_table::StateTable;
26use crate::executor::monitor::StreamingMetrics;
27use crate::executor::{
28    ActorContextRef, JoinType, NestedLoopTemporalJoinExecutor, TemporalJoinExecutor,
29};
30use crate::task::AtomicU64Ref;
31
/// Builds a temporal join executor from a [`TemporalJoinNode`], choosing between the
/// nested-loop and the hash-lookup implementation based on the node's `is_nested_loop` flag.
pub struct TemporalJoinExecutorBuilder;
33
34impl ExecutorBuilder for TemporalJoinExecutorBuilder {
35    type Node = TemporalJoinNode;
36
37    async fn new_boxed_executor(
38        params: ExecutorParams,
39        node: &Self::Node,
40        store: impl StateStore,
41    ) -> StreamResult<Executor> {
42        let table_desc: &StorageTableDesc = node.get_table_desc()?;
43        let condition = match node.get_condition() {
44            Ok(cond_prost) => Some(build_non_strict_from_prost(
45                cond_prost,
46                params.eval_error_report,
47            )?),
48            Err(_) => None,
49        };
50
51        let table_output_indices = node
52            .get_table_output_indices()
53            .iter()
54            .map(|&x| x as usize)
55            .collect_vec();
56
57        let output_indices = node
58            .get_output_indices()
59            .iter()
60            .map(|&x| x as usize)
61            .collect_vec();
62        let [source_l, source_r]: [_; 2] = params.input.try_into().unwrap();
63
64        if node.get_is_nested_loop() {
65            let right_table = BatchTable::new_partial(
66                store.clone(),
67                table_output_indices
68                    .iter()
69                    .map(|&x| ColumnId::new(table_desc.columns[x].column_id))
70                    .collect_vec(),
71                params.vnode_bitmap.clone().map(Into::into),
72                table_desc,
73            );
74
75            let dispatcher_args = NestedLoopTemporalJoinExecutorDispatcherArgs {
76                ctx: params.actor_context,
77                info: params.info.clone(),
78                left: source_l,
79                right: source_r,
80                right_table,
81                condition,
82                output_indices,
83                chunk_size: params.env.config().developer.chunk_size,
84                metrics: params.executor_stats,
85                join_type_proto: node.get_join_type()?,
86            };
87            Ok((params.info, dispatcher_args.dispatch()?).into())
88        } else {
89            let table = {
90                let column_ids = table_desc
91                    .columns
92                    .iter()
93                    .map(|x| ColumnId::new(x.column_id))
94                    .collect_vec();
95
96                BatchTable::new_partial(
97                    store.clone(),
98                    column_ids,
99                    params.vnode_bitmap.clone().map(Into::into),
100                    table_desc,
101                )
102            };
103
104            let table_stream_key_indices = table_desc
105                .stream_key
106                .iter()
107                .map(|&k| k as usize)
108                .collect_vec();
109
110            let left_join_keys = node
111                .get_left_key()
112                .iter()
113                .map(|key| *key as usize)
114                .collect_vec();
115
116            let right_join_keys = node
117                .get_right_key()
118                .iter()
119                .map(|key| *key as usize)
120                .collect_vec();
121
122            let null_safe = node.get_null_safe().to_vec();
123
124            let join_key_data_types = left_join_keys
125                .iter()
126                .map(|idx| source_l.schema().fields[*idx].data_type())
127                .collect_vec();
128
129            let memo_table = node.get_memo_table();
130            let memo_table = match memo_table {
131                Ok(memo_table) => {
132                    let vnodes = Arc::new(
133                        params
134                            .vnode_bitmap
135                            .expect("vnodes not set for temporal join"),
136                    );
137                    Some(
138                        StateTable::from_table_catalog(
139                            memo_table,
140                            store.clone(),
141                            Some(vnodes.clone()),
142                        )
143                        .await,
144                    )
145                }
146                Err(_) => None,
147            };
148            let append_only = memo_table.is_none();
149
150            let dispatcher_args = TemporalJoinExecutorDispatcherArgs {
151                ctx: params.actor_context,
152                info: params.info.clone(),
153                left: source_l,
154                right: source_r,
155                right_table: table,
156                left_join_keys,
157                right_join_keys,
158                null_safe,
159                condition,
160                output_indices,
161                table_output_indices,
162                table_stream_key_indices,
163                watermark_epoch: params.watermark_epoch,
164                chunk_size: params.env.config().developer.chunk_size,
165                metrics: params.executor_stats,
166                join_type_proto: node.get_join_type()?,
167                join_key_data_types,
168                memo_table,
169                append_only,
170            };
171
172            Ok((params.info, dispatcher_args.dispatch()?).into())
173        }
174    }
175}
176
/// Arguments collected by [`TemporalJoinExecutorBuilder`] for the hash-lookup
/// [`TemporalJoinExecutor`]. They are dispatched on `join_key_data_types` through
/// [`HashKeyDispatcher`] to pick the concrete hash key type.
struct TemporalJoinExecutorDispatcherArgs<S: StateStore> {
    ctx: ActorContextRef,
    info: ExecutorInfo,
    /// First (left) input executor of the node.
    left: Executor,
    /// Second (right) input executor of the node.
    right: Executor,
    /// Right-side table; the builder opens it over all columns of the table desc.
    right_table: BatchTable<S>,
    /// Join key column indices into the left input's schema.
    left_join_keys: Vec<usize>,
    /// Right-side join key column indices, taken from the node's `right_key`.
    right_join_keys: Vec<usize>,
    /// Per-key flags from the node's `null_safe` field (presumably null-safe key equality —
    /// confirm against the executor).
    null_safe: Vec<bool>,
    /// Optional extra join condition; `None` when the node has none.
    condition: Option<NonStrictExpression>,
    /// Output column indices, taken from the node's `output_indices`.
    output_indices: Vec<usize>,
    /// Right-table column indices, taken from the node's `table_output_indices`.
    table_output_indices: Vec<usize>,
    /// Stream key column indices of the right table (`table_desc.stream_key`).
    table_stream_key_indices: Vec<usize>,
    watermark_epoch: AtomicU64Ref,
    /// Chunk size from the developer config (`developer.chunk_size`).
    chunk_size: usize,
    metrics: Arc<StreamingMetrics>,
    /// Join type; only `Inner` and `LeftOuter` are handled by `dispatch_impl`.
    join_type_proto: JoinTypeProto,
    /// Data types of the left join key columns; returned by `data_types()` for dispatch.
    join_key_data_types: Vec<DataType>,
    /// Memo state table, present only when the node carries one.
    memo_table: Option<StateTable<S>>,
    /// Selects the append-only executor variant; the builder sets it to `memo_table.is_none()`.
    append_only: bool,
}
198
199impl<S: StateStore> HashKeyDispatcher for TemporalJoinExecutorDispatcherArgs<S> {
200    type Output = StreamResult<Box<dyn Execute>>;
201
202    fn dispatch_impl<K: HashKey>(self) -> Self::Output {
203        /// This macro helps to fill the const generic type parameter.
204        macro_rules! build {
205            ($join_type:ident, $append_only:ident) => {
206                Ok(Box::new(TemporalJoinExecutor::<
207                    K,
208                    S,
209                    { JoinType::$join_type },
210                    { $append_only },
211                >::new(
212                    self.ctx,
213                    self.info,
214                    self.left,
215                    self.right,
216                    self.right_table,
217                    self.left_join_keys,
218                    self.right_join_keys,
219                    self.null_safe,
220                    self.condition,
221                    self.output_indices,
222                    self.table_output_indices,
223                    self.table_stream_key_indices,
224                    self.watermark_epoch,
225                    self.metrics,
226                    self.chunk_size,
227                    self.join_key_data_types,
228                    self.memo_table,
229                )))
230            };
231        }
232        match self.join_type_proto {
233            JoinTypeProto::Inner => {
234                if self.append_only {
235                    build!(Inner, true)
236                } else {
237                    build!(Inner, false)
238                }
239            }
240            JoinTypeProto::LeftOuter => {
241                if self.append_only {
242                    build!(LeftOuter, true)
243                } else {
244                    build!(LeftOuter, false)
245                }
246            }
247            _ => unreachable!(),
248        }
249    }
250
251    fn data_types(&self) -> &[DataType] {
252        &self.join_key_data_types
253    }
254}
255
/// Arguments collected by [`TemporalJoinExecutorBuilder`] for the
/// [`NestedLoopTemporalJoinExecutor`].
struct NestedLoopTemporalJoinExecutorDispatcherArgs<S: StateStore> {
    ctx: ActorContextRef,
    info: ExecutorInfo,
    /// First (left) input executor of the node.
    left: Executor,
    /// Second (right) input executor of the node.
    right: Executor,
    /// Right-side table; the builder opens it over only the columns in `table_output_indices`.
    right_table: BatchTable<S>,
    /// Optional extra join condition; `None` when the node has none.
    condition: Option<NonStrictExpression>,
    /// Output column indices, taken from the node's `output_indices`.
    output_indices: Vec<usize>,
    /// Chunk size from the developer config (`developer.chunk_size`).
    chunk_size: usize,
    metrics: Arc<StreamingMetrics>,
    /// Join type; only `Inner` and `LeftOuter` are handled by `dispatch`.
    join_type_proto: JoinTypeProto,
}
268
269impl<S: StateStore> NestedLoopTemporalJoinExecutorDispatcherArgs<S> {
270    fn dispatch(self) -> StreamResult<Box<dyn Execute>> {
271        /// This macro helps to fill the const generic type parameter.
272        macro_rules! build {
273            ($join_type:ident) => {
274                Ok(Box::new(NestedLoopTemporalJoinExecutor::<
275                    S,
276                    { JoinType::$join_type },
277                >::new(
278                    self.ctx,
279                    self.info,
280                    self.left,
281                    self.right,
282                    self.right_table,
283                    self.condition,
284                    self.output_indices,
285                    self.metrics,
286                    self.chunk_size,
287                )))
288            };
289        }
290        match self.join_type_proto {
291            JoinTypeProto::Inner => {
292                build!(Inner)
293            }
294            JoinTypeProto::LeftOuter => {
295                build!(LeftOuter)
296            }
297            _ => unreachable!(),
298        }
299    }
300}