risingwave_frontend/stream_fragmenter/rewrite/
delta_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::rc::Rc;
16
17use itertools::Itertools;
18use risingwave_pb::plan_common::PbField;
19use risingwave_pb::stream_plan::lookup_node::ArrangementTableId;
20use risingwave_pb::stream_plan::stream_node::NodeBody;
21use risingwave_pb::stream_plan::{
22    DispatchStrategy, DispatcherType, ExchangeNode, LookupNode, LookupUnionNode, StreamNode,
23};
24
25use super::super::{BuildFragmentGraphState, StreamFragment, StreamFragmentEdge};
26use crate::error::Result;
27use crate::stream_fragmenter::build_and_add_fragment;
28
29fn build_no_shuffle_exchange_for_delta_join(
30    state: &mut BuildFragmentGraphState,
31    upstream: &StreamNode,
32) -> StreamNode {
33    StreamNode {
34        operator_id: state.gen_operator_id() as u64,
35        identity: "NO SHUFFLE Exchange (Lookup and Merge)".into(),
36        fields: upstream.fields.clone(),
37        stream_key: upstream.stream_key.clone(),
38        node_body: Some(NodeBody::Exchange(Box::new(ExchangeNode {
39            strategy: Some(dispatch_no_shuffle(
40                (0..(upstream.fields.len() as u32)).collect(),
41            )),
42        }))),
43        input: vec![],
44        append_only: upstream.append_only,
45    }
46}
47
48fn build_consistent_hash_shuffle_exchange_for_delta_join(
49    state: &mut BuildFragmentGraphState,
50    upstream: &StreamNode,
51    dist_key_indices: Vec<u32>,
52) -> StreamNode {
53    StreamNode {
54        operator_id: state.gen_operator_id() as u64,
55        identity: "HASH Exchange (Lookup and Merge)".into(),
56        fields: upstream.fields.clone(),
57        stream_key: upstream.stream_key.clone(),
58        node_body: Some(NodeBody::Exchange(Box::new(ExchangeNode {
59            strategy: Some(dispatch_consistent_hash_shuffle(
60                dist_key_indices,
61                (0..(upstream.fields.len() as u32)).collect(),
62            )),
63        }))),
64        input: vec![],
65        append_only: upstream.append_only,
66    }
67}
68
69fn dispatch_no_shuffle(output_indices: Vec<u32>) -> DispatchStrategy {
70    DispatchStrategy {
71        r#type: DispatcherType::NoShuffle.into(),
72        dist_key_indices: vec![],
73        output_indices,
74    }
75}
76
77fn dispatch_consistent_hash_shuffle(
78    dist_key_indices: Vec<u32>,
79    output_indices: Vec<u32>,
80) -> DispatchStrategy {
81    // Actually Hash shuffle is consistent hash shuffle now.
82    DispatchStrategy {
83        r#type: DispatcherType::Hash.into(),
84        dist_key_indices,
85        output_indices,
86    }
87}
88
89fn build_lookup_for_delta_join(
90    state: &mut BuildFragmentGraphState,
91    (exchange_node_arrangement, exchange_node_stream): (&StreamNode, &StreamNode),
92    (output_fields, output_stream_key): (Vec<PbField>, Vec<u32>),
93    lookup_node: LookupNode,
94) -> StreamNode {
95    StreamNode {
96        operator_id: state.gen_operator_id() as u64,
97        identity: "Lookup".into(),
98        fields: output_fields,
99        stream_key: output_stream_key,
100        node_body: Some(NodeBody::Lookup(Box::new(lookup_node))),
101        input: vec![
102            exchange_node_arrangement.clone(),
103            exchange_node_stream.clone(),
104        ],
105        append_only: exchange_node_stream.append_only,
106    }
107}
108
109fn build_delta_join_inner(
110    state: &mut BuildFragmentGraphState,
111    current_fragment: &StreamFragment,
112    arrange_0_frag: Rc<StreamFragment>,
113    arrange_1_frag: Rc<StreamFragment>,
114    node: &StreamNode,
115    is_local_table_id: bool,
116) -> Result<StreamNode> {
117    let delta_join_node = match &node.node_body {
118        Some(NodeBody::DeltaIndexJoin(node)) => node,
119        _ => unreachable!(),
120    };
121    let output_indices = &delta_join_node.output_indices;
122
123    let arrange_0 = arrange_0_frag.node.as_ref().unwrap();
124    let arrange_1 = arrange_1_frag.node.as_ref().unwrap();
125    let exchange_a0l0 = build_no_shuffle_exchange_for_delta_join(state, arrange_0);
126    let exchange_a0l1 = build_consistent_hash_shuffle_exchange_for_delta_join(
127        state,
128        arrange_0,
129        delta_join_node
130            .left_key
131            .iter()
132            .map(|x| *x as u32)
133            .collect_vec(),
134    );
135    let exchange_a1l0 = build_consistent_hash_shuffle_exchange_for_delta_join(
136        state,
137        arrange_1,
138        delta_join_node
139            .right_key
140            .iter()
141            .map(|x| *x as u32)
142            .collect_vec(),
143    );
144    let exchange_a1l1 = build_no_shuffle_exchange_for_delta_join(state, arrange_1);
145
146    let i0_length = arrange_0.fields.len();
147    let i1_length = arrange_1.fields.len();
148
149    let i0_output_indices = (0..i0_length as u32).collect_vec();
150    let i1_output_indices = (0..i1_length as u32).collect_vec();
151
152    let lookup_0_column_reordering = {
153        let tmp: Vec<i32> = (i1_length..i1_length + i0_length)
154            .chain(0..i1_length)
155            .map(|x| x as _)
156            .collect_vec();
157        output_indices
158            .iter()
159            .map(|&x| tmp[x as usize])
160            .collect_vec()
161    };
162    // lookup left table by right stream
163    let lookup_0 = build_lookup_for_delta_join(
164        state,
165        (&exchange_a1l0, &exchange_a0l0),
166        (node.fields.clone(), node.stream_key.clone()),
167        LookupNode {
168            stream_key: delta_join_node.right_key.clone(),
169            arrange_key: delta_join_node.left_key.clone(),
170            use_current_epoch: false,
171            // will be updated later to a global id
172            arrangement_table_id: if is_local_table_id {
173                Some(ArrangementTableId::TableId(delta_join_node.left_table_id))
174            } else {
175                Some(ArrangementTableId::IndexId(delta_join_node.left_table_id))
176            },
177            column_mapping: lookup_0_column_reordering,
178            arrangement_table_info: delta_join_node.left_info.clone(),
179        },
180    );
181    let lookup_1_column_reordering = {
182        let tmp: Vec<i32> = (0..i0_length + i1_length)
183            .chain(0..i1_length)
184            .map(|x| x as _)
185            .collect_vec();
186        output_indices
187            .iter()
188            .map(|&x| tmp[x as usize])
189            .collect_vec()
190    };
191    // lookup right table by left stream
192    let lookup_1 = build_lookup_for_delta_join(
193        state,
194        (&exchange_a0l1, &exchange_a1l1),
195        (node.fields.clone(), node.stream_key.clone()),
196        LookupNode {
197            stream_key: delta_join_node.left_key.clone(),
198            arrange_key: delta_join_node.right_key.clone(),
199            use_current_epoch: true,
200            // will be updated later to a global id
201            arrangement_table_id: if is_local_table_id {
202                Some(ArrangementTableId::TableId(delta_join_node.right_table_id))
203            } else {
204                Some(ArrangementTableId::IndexId(delta_join_node.right_table_id))
205            },
206            column_mapping: lookup_1_column_reordering,
207            arrangement_table_info: delta_join_node.right_info.clone(),
208        },
209    );
210
211    let lookup_0_frag = build_and_add_fragment(state, lookup_0)?;
212    let lookup_1_frag = build_and_add_fragment(state, lookup_1)?;
213
214    // Place index(arrange) together with corresponding lookup operator, so that we can lookup on
215    // the same node.
216    state.fragment_graph.add_edge(
217        arrange_0_frag.fragment_id,
218        lookup_0_frag.fragment_id,
219        StreamFragmentEdge {
220            dispatch_strategy: dispatch_no_shuffle(i0_output_indices.clone()),
221            link_id: exchange_a0l0.operator_id,
222        },
223    );
224
225    // Use consistent hash shuffle to distribute the index(arrange) to another lookup operator, so
226    // that we can find the correct node to lookup.
227    state.fragment_graph.add_edge(
228        arrange_0_frag.fragment_id,
229        lookup_1_frag.fragment_id,
230        StreamFragmentEdge {
231            dispatch_strategy: dispatch_consistent_hash_shuffle(
232                delta_join_node
233                    .left_key
234                    .iter()
235                    .map(|x| *x as u32)
236                    .collect_vec(),
237                i0_output_indices,
238            ),
239            link_id: exchange_a0l1.operator_id,
240        },
241    );
242
243    // Use consistent hash shuffle to distribute the index(arrange) to another lookup operator, so
244    // that we can find the correct node to lookup.
245    state.fragment_graph.add_edge(
246        arrange_1_frag.fragment_id,
247        lookup_0_frag.fragment_id,
248        StreamFragmentEdge {
249            dispatch_strategy: dispatch_consistent_hash_shuffle(
250                delta_join_node
251                    .right_key
252                    .iter()
253                    .map(|x| *x as u32)
254                    .collect_vec(),
255                i1_output_indices.clone(),
256            ),
257            link_id: exchange_a1l0.operator_id,
258        },
259    );
260
261    // Place index(arrange) together with corresponding lookup operator, so that we can lookup on
262    // the same node.
263    state.fragment_graph.add_edge(
264        arrange_1_frag.fragment_id,
265        lookup_1_frag.fragment_id,
266        StreamFragmentEdge {
267            dispatch_strategy: dispatch_no_shuffle(i1_output_indices),
268            link_id: exchange_a1l1.operator_id,
269        },
270    );
271
272    let exchange_l0m =
273        build_consistent_hash_shuffle_exchange_for_delta_join(state, node, node.stream_key.clone());
274    let exchange_l1m =
275        build_consistent_hash_shuffle_exchange_for_delta_join(state, node, node.stream_key.clone());
276
277    // LookupUnion's inputs might have different distribution and we need to unify them by using
278    // hash shuffle.
279    let union = StreamNode {
280        operator_id: state.gen_operator_id() as u64,
281        identity: "Union".into(),
282        fields: node.fields.clone(),
283        stream_key: node.stream_key.clone(),
284        node_body: Some(NodeBody::LookupUnion(Box::new(LookupUnionNode {
285            order: vec![1, 0],
286        }))),
287        input: vec![exchange_l0m.clone(), exchange_l1m.clone()],
288        append_only: node.append_only,
289    };
290
291    state.fragment_graph.add_edge(
292        lookup_0_frag.fragment_id,
293        current_fragment.fragment_id,
294        StreamFragmentEdge {
295            dispatch_strategy: dispatch_consistent_hash_shuffle(
296                node.stream_key.clone(),
297                (0..node.fields.len() as u32).collect(),
298            ),
299            link_id: exchange_l0m.operator_id,
300        },
301    );
302
303    state.fragment_graph.add_edge(
304        lookup_1_frag.fragment_id,
305        current_fragment.fragment_id,
306        StreamFragmentEdge {
307            dispatch_strategy: dispatch_consistent_hash_shuffle(
308                node.stream_key.clone(),
309                (0..node.fields.len() as u32).collect(),
310            ),
311            link_id: exchange_l1m.operator_id,
312        },
313    );
314
315    Ok(union)
316}
317
318pub(crate) fn build_delta_join_without_arrange(
319    state: &mut BuildFragmentGraphState,
320    current_fragment: &StreamFragment,
321    mut node: StreamNode,
322) -> Result<StreamNode> {
323    match &node.node_body {
324        Some(NodeBody::DeltaIndexJoin(node)) => node,
325        _ => unreachable!(),
326    };
327
328    let [arrange_0, arrange_1]: [_; 2] = std::mem::take(&mut node.input).try_into().unwrap();
329
330    // TODO: when distribution key is added to catalog, chain and delta join won't have any
331    // exchange in-between. Then we can safely remove this function.
332    fn pass_through_exchange(mut node: StreamNode) -> StreamNode {
333        if let Some(NodeBody::Exchange(exchange)) = node.node_body {
334            if let DispatcherType::NoShuffle =
335                exchange.strategy.as_ref().unwrap().get_type().unwrap()
336            {
337                return node.input.remove(0);
338            }
339            panic!("exchange other than no_shuffle not allowed between delta join and arrange");
340        } else {
341            // pass
342            node
343        }
344    }
345
346    let arrange_0 = pass_through_exchange(arrange_0);
347    let arrange_1 = pass_through_exchange(arrange_1);
348
349    let arrange_0_frag = build_and_add_fragment(state, arrange_0)?;
350    let arrange_1_frag = build_and_add_fragment(state, arrange_1)?;
351
352    let union = build_delta_join_inner(
353        state,
354        current_fragment,
355        arrange_0_frag,
356        arrange_1_frag,
357        &node,
358        false,
359    )?;
360
361    Ok(union)
362}