risingwave_stream/from_proto/
temporal_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use risingwave_common::catalog::ColumnId;
18use risingwave_common::hash::{HashKey, HashKeyDispatcher};
19use risingwave_common::types::DataType;
20use risingwave_expr::expr::{NonStrictExpression, build_non_strict_from_prost};
21use risingwave_pb::plan_common::{JoinType as JoinTypeProto, StorageTableDesc};
22use risingwave_storage::table::batch_table::BatchTable;
23
24use super::*;
25use crate::common::table::state_table::{StateTable, StateTableBuilder};
26use crate::executor::monitor::StreamingMetrics;
27use crate::executor::{
28    ActorContextRef, JoinType, NestedLoopTemporalJoinExecutor, TemporalJoinExecutor,
29};
30use crate::task::AtomicU64Ref;
31
32pub struct TemporalJoinExecutorBuilder;
33
34impl ExecutorBuilder for TemporalJoinExecutorBuilder {
35    type Node = TemporalJoinNode;
36
37    async fn new_boxed_executor(
38        params: ExecutorParams,
39        node: &Self::Node,
40        store: impl StateStore,
41    ) -> StreamResult<Executor> {
42        let table_desc: &StorageTableDesc = node.get_table_desc()?;
43        let condition = match node.get_condition() {
44            Ok(cond_prost) => Some(build_non_strict_from_prost(
45                cond_prost,
46                params.eval_error_report,
47            )?),
48            Err(_) => None,
49        };
50
51        let table_output_indices = node
52            .get_table_output_indices()
53            .iter()
54            .map(|&x| x as usize)
55            .collect_vec();
56
57        let output_indices = node
58            .get_output_indices()
59            .iter()
60            .map(|&x| x as usize)
61            .collect_vec();
62        let [source_l, source_r]: [_; 2] = params.input.try_into().unwrap();
63
64        if node.get_is_nested_loop() {
65            let right_table = BatchTable::new_partial(
66                store.clone(),
67                table_output_indices
68                    .iter()
69                    .map(|&x| ColumnId::new(table_desc.columns[x].column_id))
70                    .collect_vec(),
71                params.vnode_bitmap.clone().map(Into::into),
72                table_desc,
73            );
74
75            let dispatcher_args = NestedLoopTemporalJoinExecutorDispatcherArgs {
76                ctx: params.actor_context,
77                info: params.info.clone(),
78                left: source_l,
79                right: source_r,
80                right_table,
81                condition,
82                output_indices,
83                chunk_size: params.config.developer.chunk_size,
84                metrics: params.executor_stats,
85                join_type_proto: node.get_join_type()?,
86            };
87            Ok((params.info, dispatcher_args.dispatch()?).into())
88        } else {
89            let table = {
90                let column_ids = table_desc
91                    .columns
92                    .iter()
93                    .map(|x| ColumnId::new(x.column_id))
94                    .collect_vec();
95
96                BatchTable::new_partial(
97                    store.clone(),
98                    column_ids,
99                    params.vnode_bitmap.clone().map(Into::into),
100                    table_desc,
101                )
102            };
103
104            let table_stream_key_indices = table_desc
105                .stream_key
106                .iter()
107                .map(|&k| k as usize)
108                .collect_vec();
109
110            let left_join_keys = node
111                .get_left_key()
112                .iter()
113                .map(|key| *key as usize)
114                .collect_vec();
115
116            let right_join_keys = node
117                .get_right_key()
118                .iter()
119                .map(|key| *key as usize)
120                .collect_vec();
121
122            let null_safe = node.get_null_safe().clone();
123
124            let join_key_data_types = left_join_keys
125                .iter()
126                .map(|idx| source_l.schema().fields[*idx].data_type())
127                .collect_vec();
128
129            let memo_table = node.get_memo_table();
130            let memo_table = match memo_table {
131                Ok(memo_table) => {
132                    let vnodes = Arc::new(
133                        params
134                            .vnode_bitmap
135                            .expect("vnodes not set for temporal join"),
136                    );
137                    Some(
138                        StateTableBuilder::new(memo_table, store.clone(), Some(vnodes.clone()))
139                            .enable_preload_all_rows_by_config(&params.config)
140                            .build()
141                            .await,
142                    )
143                }
144                Err(_) => None,
145            };
146            let append_only = memo_table.is_none();
147
148            let dispatcher_args = TemporalJoinExecutorDispatcherArgs {
149                ctx: params.actor_context,
150                info: params.info.clone(),
151                left: source_l,
152                right: source_r,
153                right_table: table,
154                left_join_keys,
155                right_join_keys,
156                null_safe,
157                condition,
158                output_indices,
159                table_output_indices,
160                table_stream_key_indices,
161                watermark_epoch: params.watermark_epoch,
162                chunk_size: params.config.developer.chunk_size,
163                metrics: params.executor_stats,
164                join_type_proto: node.get_join_type()?,
165                join_key_data_types,
166                memo_table,
167                append_only,
168            };
169
170            Ok((params.info, dispatcher_args.dispatch()?).into())
171        }
172    }
173}
174
175struct TemporalJoinExecutorDispatcherArgs<S: StateStore> {
176    ctx: ActorContextRef,
177    info: ExecutorInfo,
178    left: Executor,
179    right: Executor,
180    right_table: BatchTable<S>,
181    left_join_keys: Vec<usize>,
182    right_join_keys: Vec<usize>,
183    null_safe: Vec<bool>,
184    condition: Option<NonStrictExpression>,
185    output_indices: Vec<usize>,
186    table_output_indices: Vec<usize>,
187    table_stream_key_indices: Vec<usize>,
188    watermark_epoch: AtomicU64Ref,
189    chunk_size: usize,
190    metrics: Arc<StreamingMetrics>,
191    join_type_proto: JoinTypeProto,
192    join_key_data_types: Vec<DataType>,
193    memo_table: Option<StateTable<S>>,
194    append_only: bool,
195}
196
197impl<S: StateStore> HashKeyDispatcher for TemporalJoinExecutorDispatcherArgs<S> {
198    type Output = StreamResult<Box<dyn Execute>>;
199
200    fn dispatch_impl<K: HashKey>(self) -> Self::Output {
201        /// This macro helps to fill the const generic type parameter.
202        macro_rules! build {
203            ($join_type:ident, $append_only:ident) => {
204                Ok(Box::new(TemporalJoinExecutor::<
205                    K,
206                    S,
207                    { JoinType::$join_type },
208                    { $append_only },
209                >::new(
210                    self.ctx,
211                    self.info,
212                    self.left,
213                    self.right,
214                    self.right_table,
215                    self.left_join_keys,
216                    self.right_join_keys,
217                    self.null_safe,
218                    self.condition,
219                    self.output_indices,
220                    self.table_output_indices,
221                    self.table_stream_key_indices,
222                    self.watermark_epoch,
223                    self.metrics,
224                    self.chunk_size,
225                    self.join_key_data_types,
226                    self.memo_table,
227                )))
228            };
229        }
230        match self.join_type_proto {
231            JoinTypeProto::Inner => {
232                if self.append_only {
233                    build!(Inner, true)
234                } else {
235                    build!(Inner, false)
236                }
237            }
238            JoinTypeProto::LeftOuter => {
239                if self.append_only {
240                    build!(LeftOuter, true)
241                } else {
242                    build!(LeftOuter, false)
243                }
244            }
245            _ => unreachable!(),
246        }
247    }
248
249    fn data_types(&self) -> &[DataType] {
250        &self.join_key_data_types
251    }
252}
253
254struct NestedLoopTemporalJoinExecutorDispatcherArgs<S: StateStore> {
255    ctx: ActorContextRef,
256    info: ExecutorInfo,
257    left: Executor,
258    right: Executor,
259    right_table: BatchTable<S>,
260    condition: Option<NonStrictExpression>,
261    output_indices: Vec<usize>,
262    chunk_size: usize,
263    metrics: Arc<StreamingMetrics>,
264    join_type_proto: JoinTypeProto,
265}
266
267impl<S: StateStore> NestedLoopTemporalJoinExecutorDispatcherArgs<S> {
268    fn dispatch(self) -> StreamResult<Box<dyn Execute>> {
269        /// This macro helps to fill the const generic type parameter.
270        macro_rules! build {
271            ($join_type:ident) => {
272                Ok(Box::new(NestedLoopTemporalJoinExecutor::<
273                    S,
274                    { JoinType::$join_type },
275                >::new(
276                    self.ctx,
277                    self.info,
278                    self.left,
279                    self.right,
280                    self.right_table,
281                    self.condition,
282                    self.output_indices,
283                    self.metrics,
284                    self.chunk_size,
285                )))
286            };
287        }
288        match self.join_type_proto {
289            JoinTypeProto::Inner => {
290                build!(Inner)
291            }
292            JoinTypeProto::LeftOuter => {
293                build!(LeftOuter)
294            }
295            _ => unreachable!(),
296        }
297    }
298}