risingwave_frontend/stream_fragmenter/rewrite/
delta_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::rc::Rc;
16
17use itertools::Itertools;
18use risingwave_pb::plan_common::PbField;
19use risingwave_pb::stream_plan::lookup_node::ArrangementTableId;
20use risingwave_pb::stream_plan::stream_node::NodeBody;
21use risingwave_pb::stream_plan::{
22    DispatchStrategy, DispatcherType, ExchangeNode, LookupNode, LookupUnionNode,
23    PbDispatchOutputMapping, StreamNode,
24};
25
26use super::super::{BuildFragmentGraphState, StreamFragment, StreamFragmentEdge};
27use crate::error::ErrorCode::NotSupported;
28use crate::error::Result;
29use crate::stream_fragmenter::build_and_add_fragment;
30
31fn build_no_shuffle_exchange_for_delta_join(
32    state: &mut BuildFragmentGraphState,
33    upstream: &StreamNode,
34) -> StreamNode {
35    StreamNode {
36        operator_id: state.gen_operator_id() as u64,
37        identity: "NO SHUFFLE Exchange (Lookup and Merge)".into(),
38        fields: upstream.fields.clone(),
39        stream_key: upstream.stream_key.clone(),
40        node_body: Some(NodeBody::Exchange(Box::new(ExchangeNode {
41            strategy: Some(dispatch_no_shuffle(
42                (0..(upstream.fields.len() as u32)).collect(),
43            )),
44        }))),
45        input: vec![],
46        stream_kind: upstream.stream_kind,
47    }
48}
49
50fn build_consistent_hash_shuffle_exchange_for_delta_join(
51    state: &mut BuildFragmentGraphState,
52    upstream: &StreamNode,
53    dist_key_indices: Vec<u32>,
54) -> StreamNode {
55    StreamNode {
56        operator_id: state.gen_operator_id() as u64,
57        identity: "HASH Exchange (Lookup and Merge)".into(),
58        fields: upstream.fields.clone(),
59        stream_key: upstream.stream_key.clone(),
60        node_body: Some(NodeBody::Exchange(Box::new(ExchangeNode {
61            strategy: Some(dispatch_consistent_hash_shuffle(
62                dist_key_indices,
63                (0..(upstream.fields.len() as u32)).collect(),
64            )),
65        }))),
66        input: vec![],
67        stream_kind: upstream.stream_kind,
68    }
69}
70
71fn dispatch_no_shuffle(output_indices: Vec<u32>) -> DispatchStrategy {
72    DispatchStrategy {
73        r#type: DispatcherType::NoShuffle.into(),
74        dist_key_indices: vec![],
75        output_mapping: PbDispatchOutputMapping::simple(output_indices).into(),
76    }
77}
78
79fn dispatch_consistent_hash_shuffle(
80    dist_key_indices: Vec<u32>,
81    output_indices: Vec<u32>,
82) -> DispatchStrategy {
83    // Actually Hash shuffle is consistent hash shuffle now.
84    DispatchStrategy {
85        r#type: DispatcherType::Hash.into(),
86        dist_key_indices,
87        output_mapping: PbDispatchOutputMapping::simple(output_indices).into(),
88    }
89}
90
91fn build_lookup_for_delta_join(
92    state: &mut BuildFragmentGraphState,
93    (exchange_node_arrangement, exchange_node_stream): (&StreamNode, &StreamNode),
94    (output_fields, output_stream_key): (Vec<PbField>, Vec<u32>),
95    lookup_node: LookupNode,
96) -> StreamNode {
97    StreamNode {
98        operator_id: state.gen_operator_id() as u64,
99        identity: "Lookup".into(),
100        fields: output_fields,
101        stream_key: output_stream_key,
102        node_body: Some(NodeBody::Lookup(Box::new(lookup_node))),
103        input: vec![
104            exchange_node_arrangement.clone(),
105            exchange_node_stream.clone(),
106        ],
107        stream_kind: exchange_node_stream.stream_kind,
108    }
109}
110
111fn build_delta_join_inner(
112    state: &mut BuildFragmentGraphState,
113    current_fragment: &StreamFragment,
114    arrange_0_frag: Rc<StreamFragment>,
115    arrange_1_frag: Rc<StreamFragment>,
116    node: &StreamNode,
117    is_local_table_id: bool,
118) -> Result<StreamNode> {
119    let delta_join_node = match &node.node_body {
120        Some(NodeBody::DeltaIndexJoin(node)) => node,
121        _ => unreachable!(),
122    };
123    let output_indices = &delta_join_node.output_indices;
124
125    let arrange_0 = arrange_0_frag.node.as_ref().unwrap();
126    let arrange_1 = arrange_1_frag.node.as_ref().unwrap();
127    let exchange_a0l0 = build_no_shuffle_exchange_for_delta_join(state, arrange_0);
128    let exchange_a0l1 = build_consistent_hash_shuffle_exchange_for_delta_join(
129        state,
130        arrange_0,
131        delta_join_node
132            .left_key
133            .iter()
134            .map(|x| *x as u32)
135            .collect_vec(),
136    );
137    let exchange_a1l0 = build_consistent_hash_shuffle_exchange_for_delta_join(
138        state,
139        arrange_1,
140        delta_join_node
141            .right_key
142            .iter()
143            .map(|x| *x as u32)
144            .collect_vec(),
145    );
146    let exchange_a1l1 = build_no_shuffle_exchange_for_delta_join(state, arrange_1);
147
148    let i0_length = arrange_0.fields.len();
149    let i1_length = arrange_1.fields.len();
150
151    let i0_output_indices = (0..i0_length as u32).collect_vec();
152    let i1_output_indices = (0..i1_length as u32).collect_vec();
153
154    let lookup_0_column_reordering = {
155        let tmp: Vec<i32> = (i1_length..i1_length + i0_length)
156            .chain(0..i1_length)
157            .map(|x| x as _)
158            .collect_vec();
159        output_indices
160            .iter()
161            .map(|&x| tmp[x as usize])
162            .collect_vec()
163    };
164    // lookup left table by right stream
165    let lookup_0 = build_lookup_for_delta_join(
166        state,
167        (&exchange_a1l0, &exchange_a0l0),
168        (node.fields.clone(), node.stream_key.clone()),
169        LookupNode {
170            stream_key: delta_join_node.right_key.clone(),
171            arrange_key: delta_join_node.left_key.clone(),
172            use_current_epoch: false,
173            // will be updated later to a global id
174            arrangement_table_id: if is_local_table_id {
175                Some(ArrangementTableId::TableId(
176                    delta_join_node.left_table_id.as_raw_id(),
177                ))
178            } else {
179                Some(ArrangementTableId::IndexId(
180                    delta_join_node.left_table_id.as_raw_id(),
181                ))
182            },
183            column_mapping: lookup_0_column_reordering,
184            arrangement_table_info: delta_join_node.left_info.clone(),
185        },
186    );
187    let lookup_1_column_reordering = {
188        let tmp: Vec<i32> = (0..i0_length + i1_length)
189            .chain(0..i1_length)
190            .map(|x| x as _)
191            .collect_vec();
192        output_indices
193            .iter()
194            .map(|&x| tmp[x as usize])
195            .collect_vec()
196    };
197    // lookup right table by left stream
198    let lookup_1 = build_lookup_for_delta_join(
199        state,
200        (&exchange_a0l1, &exchange_a1l1),
201        (node.fields.clone(), node.stream_key.clone()),
202        LookupNode {
203            stream_key: delta_join_node.left_key.clone(),
204            arrange_key: delta_join_node.right_key.clone(),
205            use_current_epoch: true,
206            // will be updated later to a global id
207            arrangement_table_id: if is_local_table_id {
208                Some(ArrangementTableId::TableId(
209                    delta_join_node.right_table_id.as_raw_id(),
210                ))
211            } else {
212                Some(ArrangementTableId::IndexId(
213                    delta_join_node.right_table_id.as_raw_id(),
214                ))
215            },
216            column_mapping: lookup_1_column_reordering,
217            arrangement_table_info: delta_join_node.right_info.clone(),
218        },
219    );
220
221    let lookup_0_frag = build_and_add_fragment(state, lookup_0)?;
222    let lookup_1_frag = build_and_add_fragment(state, lookup_1)?;
223
224    // Place index(arrange) together with corresponding lookup operator, so that we can lookup on
225    // the same node.
226    state.fragment_graph.add_edge(
227        arrange_0_frag.fragment_id,
228        lookup_0_frag.fragment_id,
229        StreamFragmentEdge {
230            dispatch_strategy: dispatch_no_shuffle(i0_output_indices.clone()),
231            link_id: exchange_a0l0.operator_id,
232        },
233    );
234
235    // Use consistent hash shuffle to distribute the index(arrange) to another lookup operator, so
236    // that we can find the correct node to lookup.
237    state.fragment_graph.add_edge(
238        arrange_0_frag.fragment_id,
239        lookup_1_frag.fragment_id,
240        StreamFragmentEdge {
241            dispatch_strategy: dispatch_consistent_hash_shuffle(
242                delta_join_node
243                    .left_key
244                    .iter()
245                    .map(|x| *x as u32)
246                    .collect_vec(),
247                i0_output_indices,
248            ),
249            link_id: exchange_a0l1.operator_id,
250        },
251    );
252
253    // Use consistent hash shuffle to distribute the index(arrange) to another lookup operator, so
254    // that we can find the correct node to lookup.
255    state.fragment_graph.add_edge(
256        arrange_1_frag.fragment_id,
257        lookup_0_frag.fragment_id,
258        StreamFragmentEdge {
259            dispatch_strategy: dispatch_consistent_hash_shuffle(
260                delta_join_node
261                    .right_key
262                    .iter()
263                    .map(|x| *x as u32)
264                    .collect_vec(),
265                i1_output_indices.clone(),
266            ),
267            link_id: exchange_a1l0.operator_id,
268        },
269    );
270
271    // Place index(arrange) together with corresponding lookup operator, so that we can lookup on
272    // the same node.
273    state.fragment_graph.add_edge(
274        arrange_1_frag.fragment_id,
275        lookup_1_frag.fragment_id,
276        StreamFragmentEdge {
277            dispatch_strategy: dispatch_no_shuffle(i1_output_indices),
278            link_id: exchange_a1l1.operator_id,
279        },
280    );
281
282    let exchange_l0m =
283        build_consistent_hash_shuffle_exchange_for_delta_join(state, node, node.stream_key.clone());
284    let exchange_l1m =
285        build_consistent_hash_shuffle_exchange_for_delta_join(state, node, node.stream_key.clone());
286
287    // LookupUnion's inputs might have different distribution and we need to unify them by using
288    // hash shuffle.
289    let union = StreamNode {
290        operator_id: state.gen_operator_id() as u64,
291        identity: "Union".into(),
292        fields: node.fields.clone(),
293        stream_key: node.stream_key.clone(),
294        node_body: Some(NodeBody::LookupUnion(Box::new(LookupUnionNode {
295            order: vec![1, 0],
296        }))),
297        input: vec![exchange_l0m.clone(), exchange_l1m.clone()],
298        stream_kind: node.stream_kind,
299    };
300
301    state.fragment_graph.add_edge(
302        lookup_0_frag.fragment_id,
303        current_fragment.fragment_id,
304        StreamFragmentEdge {
305            dispatch_strategy: dispatch_consistent_hash_shuffle(
306                node.stream_key.clone(),
307                (0..node.fields.len() as u32).collect(),
308            ),
309            link_id: exchange_l0m.operator_id,
310        },
311    );
312
313    state.fragment_graph.add_edge(
314        lookup_1_frag.fragment_id,
315        current_fragment.fragment_id,
316        StreamFragmentEdge {
317            dispatch_strategy: dispatch_consistent_hash_shuffle(
318                node.stream_key.clone(),
319                (0..node.fields.len() as u32).collect(),
320            ),
321            link_id: exchange_l1m.operator_id,
322        },
323    );
324
325    Ok(union)
326}
327
328pub(crate) fn build_delta_join_without_arrange(
329    state: &mut BuildFragmentGraphState,
330    current_fragment: &StreamFragment,
331    mut node: StreamNode,
332) -> Result<StreamNode> {
333    match &node.node_body {
334        Some(NodeBody::DeltaIndexJoin(node)) => node,
335        _ => unreachable!(),
336    };
337
338    let [arrange_0, arrange_1]: [_; 2] = std::mem::take(&mut node.input).try_into().unwrap();
339
340    // TODO: when distribution key is added to catalog, chain and delta join won't have any
341    // exchange in-between. Then we can safely remove this function.
342    fn pass_through_exchange(mut node: StreamNode) -> StreamNode {
343        if let Some(NodeBody::Exchange(exchange)) = node.node_body {
344            if let DispatcherType::NoShuffle =
345                exchange.strategy.as_ref().unwrap().get_type().unwrap()
346            {
347                return node.input.remove(0);
348            }
349            panic!("exchange other than no_shuffle not allowed between delta join and arrange");
350        } else {
351            // pass
352            node
353        }
354    }
355
356    let arrange_0 = pass_through_exchange(arrange_0);
357    let arrange_1 = pass_through_exchange(arrange_1);
358
359    let arrange_0_frag = build_and_add_fragment(state, arrange_0)?;
360    let arrange_1_frag = build_and_add_fragment(state, arrange_1)?;
361
362    let union = build_delta_join_inner(
363        state,
364        current_fragment,
365        arrange_0_frag,
366        arrange_1_frag,
367        &node,
368        false,
369    )?;
370
371    if state.has_snapshot_backfill {
372        return Err(NotSupported(
373            "Delta join with snapshot backfill is not supported".to_owned(),
374            "Please use a different join strategy or disable snapshot backfill by `SET streaming_use_snapshot_backfill = false`.".to_owned(),
375        ).into());
376    }
377
378    Ok(union)
379}