risingwave_frontend/stream_fragmenter/rewrite/
delta_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::rc::Rc;
16
17use itertools::Itertools;
18use risingwave_pb::plan_common::PbField;
19use risingwave_pb::stream_plan::lookup_node::ArrangementTableId;
20use risingwave_pb::stream_plan::stream_node::NodeBody;
21use risingwave_pb::stream_plan::{
22    DispatchStrategy, DispatcherType, ExchangeNode, LookupNode, LookupUnionNode,
23    PbDispatchOutputMapping, StreamNode,
24};
25
26use super::super::{BuildFragmentGraphState, StreamFragment, StreamFragmentEdge};
27use crate::error::Result;
28use crate::stream_fragmenter::build_and_add_fragment;
29
30fn build_no_shuffle_exchange_for_delta_join(
31    state: &mut BuildFragmentGraphState,
32    upstream: &StreamNode,
33) -> StreamNode {
34    StreamNode {
35        operator_id: state.gen_operator_id() as u64,
36        identity: "NO SHUFFLE Exchange (Lookup and Merge)".into(),
37        fields: upstream.fields.clone(),
38        stream_key: upstream.stream_key.clone(),
39        node_body: Some(NodeBody::Exchange(Box::new(ExchangeNode {
40            strategy: Some(dispatch_no_shuffle(
41                (0..(upstream.fields.len() as u32)).collect(),
42            )),
43        }))),
44        input: vec![],
45        append_only: upstream.append_only,
46    }
47}
48
49fn build_consistent_hash_shuffle_exchange_for_delta_join(
50    state: &mut BuildFragmentGraphState,
51    upstream: &StreamNode,
52    dist_key_indices: Vec<u32>,
53) -> StreamNode {
54    StreamNode {
55        operator_id: state.gen_operator_id() as u64,
56        identity: "HASH Exchange (Lookup and Merge)".into(),
57        fields: upstream.fields.clone(),
58        stream_key: upstream.stream_key.clone(),
59        node_body: Some(NodeBody::Exchange(Box::new(ExchangeNode {
60            strategy: Some(dispatch_consistent_hash_shuffle(
61                dist_key_indices,
62                (0..(upstream.fields.len() as u32)).collect(),
63            )),
64        }))),
65        input: vec![],
66        append_only: upstream.append_only,
67    }
68}
69
70fn dispatch_no_shuffle(output_indices: Vec<u32>) -> DispatchStrategy {
71    DispatchStrategy {
72        r#type: DispatcherType::NoShuffle.into(),
73        dist_key_indices: vec![],
74        output_mapping: PbDispatchOutputMapping::simple(output_indices).into(),
75    }
76}
77
78fn dispatch_consistent_hash_shuffle(
79    dist_key_indices: Vec<u32>,
80    output_indices: Vec<u32>,
81) -> DispatchStrategy {
82    // Actually Hash shuffle is consistent hash shuffle now.
83    DispatchStrategy {
84        r#type: DispatcherType::Hash.into(),
85        dist_key_indices,
86        output_mapping: PbDispatchOutputMapping::simple(output_indices).into(),
87    }
88}
89
90fn build_lookup_for_delta_join(
91    state: &mut BuildFragmentGraphState,
92    (exchange_node_arrangement, exchange_node_stream): (&StreamNode, &StreamNode),
93    (output_fields, output_stream_key): (Vec<PbField>, Vec<u32>),
94    lookup_node: LookupNode,
95) -> StreamNode {
96    StreamNode {
97        operator_id: state.gen_operator_id() as u64,
98        identity: "Lookup".into(),
99        fields: output_fields,
100        stream_key: output_stream_key,
101        node_body: Some(NodeBody::Lookup(Box::new(lookup_node))),
102        input: vec![
103            exchange_node_arrangement.clone(),
104            exchange_node_stream.clone(),
105        ],
106        append_only: exchange_node_stream.append_only,
107    }
108}
109
110fn build_delta_join_inner(
111    state: &mut BuildFragmentGraphState,
112    current_fragment: &StreamFragment,
113    arrange_0_frag: Rc<StreamFragment>,
114    arrange_1_frag: Rc<StreamFragment>,
115    node: &StreamNode,
116    is_local_table_id: bool,
117) -> Result<StreamNode> {
118    let delta_join_node = match &node.node_body {
119        Some(NodeBody::DeltaIndexJoin(node)) => node,
120        _ => unreachable!(),
121    };
122    let output_indices = &delta_join_node.output_indices;
123
124    let arrange_0 = arrange_0_frag.node.as_ref().unwrap();
125    let arrange_1 = arrange_1_frag.node.as_ref().unwrap();
126    let exchange_a0l0 = build_no_shuffle_exchange_for_delta_join(state, arrange_0);
127    let exchange_a0l1 = build_consistent_hash_shuffle_exchange_for_delta_join(
128        state,
129        arrange_0,
130        delta_join_node
131            .left_key
132            .iter()
133            .map(|x| *x as u32)
134            .collect_vec(),
135    );
136    let exchange_a1l0 = build_consistent_hash_shuffle_exchange_for_delta_join(
137        state,
138        arrange_1,
139        delta_join_node
140            .right_key
141            .iter()
142            .map(|x| *x as u32)
143            .collect_vec(),
144    );
145    let exchange_a1l1 = build_no_shuffle_exchange_for_delta_join(state, arrange_1);
146
147    let i0_length = arrange_0.fields.len();
148    let i1_length = arrange_1.fields.len();
149
150    let i0_output_indices = (0..i0_length as u32).collect_vec();
151    let i1_output_indices = (0..i1_length as u32).collect_vec();
152
153    let lookup_0_column_reordering = {
154        let tmp: Vec<i32> = (i1_length..i1_length + i0_length)
155            .chain(0..i1_length)
156            .map(|x| x as _)
157            .collect_vec();
158        output_indices
159            .iter()
160            .map(|&x| tmp[x as usize])
161            .collect_vec()
162    };
163    // lookup left table by right stream
164    let lookup_0 = build_lookup_for_delta_join(
165        state,
166        (&exchange_a1l0, &exchange_a0l0),
167        (node.fields.clone(), node.stream_key.clone()),
168        LookupNode {
169            stream_key: delta_join_node.right_key.clone(),
170            arrange_key: delta_join_node.left_key.clone(),
171            use_current_epoch: false,
172            // will be updated later to a global id
173            arrangement_table_id: if is_local_table_id {
174                Some(ArrangementTableId::TableId(delta_join_node.left_table_id))
175            } else {
176                Some(ArrangementTableId::IndexId(delta_join_node.left_table_id))
177            },
178            column_mapping: lookup_0_column_reordering,
179            arrangement_table_info: delta_join_node.left_info.clone(),
180        },
181    );
182    let lookup_1_column_reordering = {
183        let tmp: Vec<i32> = (0..i0_length + i1_length)
184            .chain(0..i1_length)
185            .map(|x| x as _)
186            .collect_vec();
187        output_indices
188            .iter()
189            .map(|&x| tmp[x as usize])
190            .collect_vec()
191    };
192    // lookup right table by left stream
193    let lookup_1 = build_lookup_for_delta_join(
194        state,
195        (&exchange_a0l1, &exchange_a1l1),
196        (node.fields.clone(), node.stream_key.clone()),
197        LookupNode {
198            stream_key: delta_join_node.left_key.clone(),
199            arrange_key: delta_join_node.right_key.clone(),
200            use_current_epoch: true,
201            // will be updated later to a global id
202            arrangement_table_id: if is_local_table_id {
203                Some(ArrangementTableId::TableId(delta_join_node.right_table_id))
204            } else {
205                Some(ArrangementTableId::IndexId(delta_join_node.right_table_id))
206            },
207            column_mapping: lookup_1_column_reordering,
208            arrangement_table_info: delta_join_node.right_info.clone(),
209        },
210    );
211
212    let lookup_0_frag = build_and_add_fragment(state, lookup_0)?;
213    let lookup_1_frag = build_and_add_fragment(state, lookup_1)?;
214
215    // Place index(arrange) together with corresponding lookup operator, so that we can lookup on
216    // the same node.
217    state.fragment_graph.add_edge(
218        arrange_0_frag.fragment_id,
219        lookup_0_frag.fragment_id,
220        StreamFragmentEdge {
221            dispatch_strategy: dispatch_no_shuffle(i0_output_indices.clone()),
222            link_id: exchange_a0l0.operator_id,
223        },
224    );
225
226    // Use consistent hash shuffle to distribute the index(arrange) to another lookup operator, so
227    // that we can find the correct node to lookup.
228    state.fragment_graph.add_edge(
229        arrange_0_frag.fragment_id,
230        lookup_1_frag.fragment_id,
231        StreamFragmentEdge {
232            dispatch_strategy: dispatch_consistent_hash_shuffle(
233                delta_join_node
234                    .left_key
235                    .iter()
236                    .map(|x| *x as u32)
237                    .collect_vec(),
238                i0_output_indices,
239            ),
240            link_id: exchange_a0l1.operator_id,
241        },
242    );
243
244    // Use consistent hash shuffle to distribute the index(arrange) to another lookup operator, so
245    // that we can find the correct node to lookup.
246    state.fragment_graph.add_edge(
247        arrange_1_frag.fragment_id,
248        lookup_0_frag.fragment_id,
249        StreamFragmentEdge {
250            dispatch_strategy: dispatch_consistent_hash_shuffle(
251                delta_join_node
252                    .right_key
253                    .iter()
254                    .map(|x| *x as u32)
255                    .collect_vec(),
256                i1_output_indices.clone(),
257            ),
258            link_id: exchange_a1l0.operator_id,
259        },
260    );
261
262    // Place index(arrange) together with corresponding lookup operator, so that we can lookup on
263    // the same node.
264    state.fragment_graph.add_edge(
265        arrange_1_frag.fragment_id,
266        lookup_1_frag.fragment_id,
267        StreamFragmentEdge {
268            dispatch_strategy: dispatch_no_shuffle(i1_output_indices),
269            link_id: exchange_a1l1.operator_id,
270        },
271    );
272
273    let exchange_l0m =
274        build_consistent_hash_shuffle_exchange_for_delta_join(state, node, node.stream_key.clone());
275    let exchange_l1m =
276        build_consistent_hash_shuffle_exchange_for_delta_join(state, node, node.stream_key.clone());
277
278    // LookupUnion's inputs might have different distribution and we need to unify them by using
279    // hash shuffle.
280    let union = StreamNode {
281        operator_id: state.gen_operator_id() as u64,
282        identity: "Union".into(),
283        fields: node.fields.clone(),
284        stream_key: node.stream_key.clone(),
285        node_body: Some(NodeBody::LookupUnion(Box::new(LookupUnionNode {
286            order: vec![1, 0],
287        }))),
288        input: vec![exchange_l0m.clone(), exchange_l1m.clone()],
289        append_only: node.append_only,
290    };
291
292    state.fragment_graph.add_edge(
293        lookup_0_frag.fragment_id,
294        current_fragment.fragment_id,
295        StreamFragmentEdge {
296            dispatch_strategy: dispatch_consistent_hash_shuffle(
297                node.stream_key.clone(),
298                (0..node.fields.len() as u32).collect(),
299            ),
300            link_id: exchange_l0m.operator_id,
301        },
302    );
303
304    state.fragment_graph.add_edge(
305        lookup_1_frag.fragment_id,
306        current_fragment.fragment_id,
307        StreamFragmentEdge {
308            dispatch_strategy: dispatch_consistent_hash_shuffle(
309                node.stream_key.clone(),
310                (0..node.fields.len() as u32).collect(),
311            ),
312            link_id: exchange_l1m.operator_id,
313        },
314    );
315
316    Ok(union)
317}
318
319pub(crate) fn build_delta_join_without_arrange(
320    state: &mut BuildFragmentGraphState,
321    current_fragment: &StreamFragment,
322    mut node: StreamNode,
323) -> Result<StreamNode> {
324    match &node.node_body {
325        Some(NodeBody::DeltaIndexJoin(node)) => node,
326        _ => unreachable!(),
327    };
328
329    let [arrange_0, arrange_1]: [_; 2] = std::mem::take(&mut node.input).try_into().unwrap();
330
331    // TODO: when distribution key is added to catalog, chain and delta join won't have any
332    // exchange in-between. Then we can safely remove this function.
333    fn pass_through_exchange(mut node: StreamNode) -> StreamNode {
334        if let Some(NodeBody::Exchange(exchange)) = node.node_body {
335            if let DispatcherType::NoShuffle =
336                exchange.strategy.as_ref().unwrap().get_type().unwrap()
337            {
338                return node.input.remove(0);
339            }
340            panic!("exchange other than no_shuffle not allowed between delta join and arrange");
341        } else {
342            // pass
343            node
344        }
345    }
346
347    let arrange_0 = pass_through_exchange(arrange_0);
348    let arrange_1 = pass_through_exchange(arrange_1);
349
350    let arrange_0_frag = build_and_add_fragment(state, arrange_0)?;
351    let arrange_1_frag = build_and_add_fragment(state, arrange_1)?;
352
353    let union = build_delta_join_inner(
354        state,
355        current_fragment,
356        arrange_0_frag,
357        arrange_1_frag,
358        &node,
359        false,
360    )?;
361
362    Ok(union)
363}