risingwave_frontend/optimizer/
delta_join_solver.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The solver for delta join, which determines lookup order of a join plan.
16//! All collection types in this module should be `BTree` to ensure determinism between runs.
17//!
18//! # Representation of Multi-Way Join
19//!
//! In this module, `a*` denotes a lookup executor that looks up arrangement `a` as of the
//! current epoch, while `a` denotes one that looks up arrangement `a` as of the previous epoch.
22//!
23//! Delta joins only support inner equal join. The solver is based on the following formula (take
24//! 3-way join as an example):
25//!
26//! ```plain
27//! d((A join1 B) join2 C)
28//! = ((A + dA) join1 (B + dB)) join2 (C + dC) - A join1 B join2 C
29//! = (A join1 B + A join1 dB + dA join1 (B + dB)) join2 (C + dC) - A join1 B join2 C
30//! = A join1 B join2 (C + dC) + A join1 dB join2 (C + dC) + dA join1 (B + dB) join2 (C + dC) - A join1 B join2 C
31//! = A join1 B join2 dC + A join1 dB join2 (C + dC) + dA join1 (B + dB) join2 (C + dC)
32//! = dA join1 (B + dB) join2 (C + dC) + dB join1 A join2 (C + dC) + dC join2 B join1 A
33//!
34//! join1 means A join B using condition #1,
35//! join2 means B join C using condition #2.
36//! ```
37//!
//! Inner joins are both commutative and associative, so the operands can be freely reordered
//! across joins.
40//!
41//! ... which generates the following look-up graph:
42//!
43//! ```plain
44//! a -> b* -> c* -> output3
45//! b -> a  -> c* -> output2
46//! c -> b  -> a  -> output1
47//! ```
48//!
49//! And the final output is `output1 <concat> output2 <concat> output3`. The concatenation is
50//! ordered: all items from output 1 must appear before output 2.
51//!
52//! TODO: support dynamic filter and filter condition in lookups.
53//!
//! # What does the caller need to do?
//!
//! After solving the delta join plan, the caller still needs to do the following:
57//!
58//! * Use the correct index for every stream input. By looking at the first lookup fragment in each
59//!   row, we can decide whether to use `a.x` or `a.y` as index for stream input.
60//! * Insert exchanges between lookups of different distribution. Generally, if the whole row is
61//!   operating on the same join key, we only need to do 1v1 exchanges between lookups. However, it
62//!   would be possible that a row of lookup first join `a.x == b.x`, then `a.y == c.y`. In this
63//!   case, we will need to insert hash exchange between these two lookups.
64//! * Ensure the order of union. Always union from the last row to the first row.
65//! * Insert exchange before union. Still the case for `a.x == b.x`, then `a.y == c.y`, it is
66//!   possible that every lookup path produces different distribution. We need to shuffle them
67//!   before feeding data to union.
68
69// FIXME: https://github.com/rust-lang/rust-analyzer/issues/17685
70#![allow(dead_code)]
71
72use std::collections::{BTreeMap, BTreeSet};
73
74use anyhow::{Result, anyhow};
75use itertools::Itertools;
76
/// Identifier of one input table taking part in the multi-way join.
///
/// The wrapped `usize` is treated as an opaque id by this module. The type is totally ordered so
/// that it can be used as a key in the `BTree` collections the solver relies on.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct JoinTable(pub usize);
79
80/// Represents whether `left` and `right` can be joined using a condition.
81#[derive(Debug, Clone, Eq, PartialEq)]
82pub struct JoinEdge {
83    pub left: JoinTable,
84    pub right: JoinTable,
85    pub left_join_key: Vec<usize>,
86    pub right_join_key: Vec<usize>,
87}
88
89impl JoinEdge {
90    /// Reverse the order of the edge.
91    pub fn reverse(&self) -> Self {
92        Self {
93            left: self.right,
94            right: self.left,
95            left_join_key: self.right_join_key.clone(),
96            right_join_key: self.left_join_key.clone(),
97        }
98    }
99}
100
/// Controls which arrangement goes into which lookup node. Consider the lookup slots of a 3-way
/// join:
///
/// ```plain
/// a -> 1* -> 2* ->
/// b -> 3  -> 4* ->
/// c -> 5  -> 6  ->
/// ```
///
/// Suppose the user-provided multi-way join order is `(a join b) join c`, the strategy is
/// [`ArrangeStrategy::LeftFirst`], and all three tables share one join key (the graph is
/// fully-connected, so no shuffle is needed). The arrangements are then placed as follows:
///
/// ```plain
/// a -> b* -> c* ->
/// b -> a  -> c* ->
/// c -> a  -> b  ->
/// ```
///
/// That is, when several arrangements are candidates for a lookup, the one from the left side
/// of the joins is placed in the earlier (left) lookup.
///
/// With [`ArrangeStrategy::RightFirst`] the result is instead:
///
/// ```plain
/// a -> c* -> b* ->
/// b -> c  -> a* ->
/// c -> b  -> a  ->
/// ```
///
/// That is, the arrangement from the right side of the joins is placed in the earlier (left)
/// lookup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArrangeStrategy {
    /// Prefer the left-most table of the join order as the first lookup table.
    LeftFirst,
    /// Prefer the right-most table of the join order as the first lookup table.
    RightFirst,
}
138
/// Controls which stream input feeds which lookup row. Consider the rows of a 3-way join:
///
/// ```plain
/// x -> 1* -> 2* ->
/// x -> 3  -> 4* ->
/// x -> 5  -> 6  ->
/// ```
///
/// ... where `n*` marks a lookup of this epoch. Suppose the user-provided multi-way join order is
/// `(a join b) join c`. With [`StreamStrategy::LeftThisEpoch`] the stream sides are placed as:
///
/// ```plain
/// a -> 1* -> 2* ->
/// b -> 3  -> 4* ->
/// c -> 5  -> 6  ->
/// ```
///
/// ... so `a`, the left-most table of the multi-way join, goes through the largest number of
/// lookup-this-epoch nodes.
///
/// With [`StreamStrategy::RightThisEpoch`] the placement is:
///
/// ```plain
/// c -> 1* -> 2* ->
/// b -> 3  -> 4* ->
/// a -> 5  -> 6  ->
/// ```
///
/// ... so `c`, the right-most table of the multi-way join, goes through the largest number of
/// lookup-this-epoch nodes.
///
/// The more lookup-this-epoch nodes a row contains, the higher its latency when a barrier is
/// flushed: a lookup executor configured to look up this epoch waits for a barrier from the
/// arrangement side before it actually starts looking up and joining.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamStrategy {
    /// Prefer the left-most table of the join order as the stream that looks up this epoch.
    LeftThisEpoch,
    /// Prefer the right-most table of the join order as the stream that looks up this epoch.
    RightThisEpoch,
}
181
/// Given a multi-way inner join plan, the solver decides the order in which each stream input
/// looks up the other tables' arrangements.
///
/// ## Input
///
/// [`DeltaJoinSolver`] needs to know which tables can be joined with each other. You'll need to
/// provide the edges of the join graph: each [`JoinEdge`] `(a, b)` states that `a` and `b` can be
/// inner-joined, and on which join keys.
///
/// ## Output
///
/// The `solve` function will return a vector of [`LookupPath`]. See [`LookupPath`] for more.
pub struct DeltaJoinSolver {
    /// Strategy of arrangement placement, see docs of [`ArrangeStrategy`] for more details.
    arrange_strategy: ArrangeStrategy,

    /// Strategy of stream placement, see docs of [`StreamStrategy`] for more details.
    stream_strategy: StreamStrategy,

    /// Possible combination of joins. If `(a, b)` is in `edges`, then `a` and `b` can be joined.
    /// The edge is bi-directional, so only one pair (either `a, b` or `b, a`) needs to be present
    /// in this vector.
    edges: Vec<JoinEdge>,

    /// The recommended join order from optimizer in a multi-way join plan node.
    join_order: Vec<JoinTable>,
}
209
210/// A row in the lookup plan, which includes the stream side arrangement, and all arrangements to be
211/// placed in the lookup node.
212///
213/// For example, if we have `LookupPath(a, vec![b, c])`, then we have a row:
214///
215/// ```plain
216/// a -> b -> c
217/// ```
218///
219/// In `Vec<LookupPath>`, the first row always contains the most lookup-this-epochs. For example, if
220/// we have:
221///
222/// ```plain
223/// vec![
224///   LookupPath(a, vec![b, c]),
225///   LookupPath(b, vec![c, a]),
226///   LookupPath(c, vec![a, b])
227/// ]
228/// ```
229///
230/// Then it means...
231///
232/// ```plain
233/// a -> b* -> c* ->
234/// b -> c  -> a* ->
235/// c -> a  -> b  ->
236/// ```
237#[derive(Clone, Debug, Eq, PartialEq)]
238pub struct LookupPath(JoinTable, pub Vec<JoinTable>);
239
/// Pre-computed lookup structures derived from a [`DeltaJoinSolver`], shared by all lookup-path
/// searches of one `solve` call.
struct SolverEnv {
    /// Adjacency list of the join graph. Every edge is stored under both of its tables, oriented
    /// so that the [`JoinEdge`]'s left side is always the key in the map.
    join_edge: BTreeMap<JoinTable, Vec<JoinEdge>>,

    /// Placement order of arrangements: the join order, reversed for
    /// [`ArrangeStrategy::RightFirst`].
    arrange_placement_order: Vec<JoinTable>,

    /// Placement order of streams: the join order, reversed for
    /// [`StreamStrategy::RightThisEpoch`].
    stream_placement_order: Vec<JoinTable>,
}
250
251impl SolverEnv {
252    fn build_from(solver: &DeltaJoinSolver) -> Self {
253        let mut join_edge = BTreeMap::new();
254
255        for table in &solver.join_order {
256            join_edge.insert(*table, vec![]);
257        }
258        for edge in &solver.edges {
259            join_edge.get_mut(&edge.left).unwrap().push(edge.clone());
260            join_edge.get_mut(&edge.right).unwrap().push(edge.reverse());
261        }
262
263        Self {
264            arrange_placement_order: match solver.arrange_strategy {
265                ArrangeStrategy::LeftFirst => solver.join_order.clone(),
266                ArrangeStrategy::RightFirst => {
267                    solver.join_order.iter().copied().rev().collect_vec()
268                }
269            },
270            stream_placement_order: match solver.stream_strategy {
271                StreamStrategy::LeftThisEpoch => solver.join_order.clone(),
272                StreamStrategy::RightThisEpoch => {
273                    solver.join_order.iter().copied().rev().collect_vec()
274                }
275            },
276            join_edge,
277        }
278    }
279}
280
281impl DeltaJoinSolver {
282    /// Generate a lookup path using the user provided strategy. The lookup path is generated in the
283    /// following way:
284    ///
285    /// * Firstly, we find all tables that can be joined with the current join table set. The tables
286    ///   should have the same join key as the current join set.
287    /// * If there are multiple tables that satisfy this condition, we will pick tables according to
288    ///   [`ArrangeStrategy`].
289    /// * If not, then we have to do a shuffle. We pick all tables that have join condition with the
290    ///   current table set.
291    /// * And then pick a table using [`ArrangeStrategy`].
292    ///
293    /// Basically, this algorithm will greedily pick all tables with the same join key first (so
294    /// that there won't be shuffle). Then, switch a distribution, and do this process again.
295    fn find_lookup_path(
296        &self,
297        solver_env: &SolverEnv,
298        input_stream: JoinTable,
299    ) -> Result<LookupPath> {
300        // The table available to query
301        let mut current_table_set = BTreeSet::new();
302        current_table_set.insert(input_stream);
303
304        // The distribution of the current lookup.
305        let mut current_distribution = vec![];
306
307        // The final lookup path.
308        let mut path = vec![];
309
310        fn satisfies_distribution(
311            current_distribution: &[(JoinTable, Vec<usize>)],
312            edge: &JoinEdge,
313        ) -> bool {
314            // TODO: if `current_distribution` is a `BTreeSet`, we can know if the distribution in
315            // O(1). But as we generally have few tables, so we do a linear scan on tables.
316            if current_distribution.is_empty() {
317                return true;
318            }
319
320            for (table, join_key) in current_distribution {
321                if table == &edge.left && join_key == &edge.left_join_key {
322                    return true;
323                }
324            }
325
326            false
327        }
328
329        'next_table: loop {
330            assert!(
331                path.len() < self.join_order.len(),
332                "internal error: infinite loop"
333            );
334
335            // step 1: find tables that can be joined with `current_table_set` and satisfy the
336            // current distribution.
337            let mut reachable_tables = BTreeMap::new();
338
339            for current_table in &current_table_set {
340                for edge in &solver_env.join_edge[current_table] {
341                    if !current_table_set.contains(&edge.right)
342                        && satisfies_distribution(&current_distribution, edge)
343                    {
344                        reachable_tables.insert(edge.right, edge.clone());
345                    }
346                }
347            }
348
349            // step 2: place arrangements according to the arrange strategy, update current
350            // distribution and current table set.
351            for table in &solver_env.arrange_placement_order {
352                if let Some(edge) = reachable_tables.get(table) {
353                    current_table_set.insert(edge.right);
354                    path.push(edge.right);
355                    current_distribution.push((edge.right, edge.right_join_key.clone()));
356                    continue 'next_table;
357                }
358            }
359
360            // no table can be joined using the same distribution at this point.
361
362            // step 3: find all tables that can be joined with `current_table_set`, regardless of
363            // their distribution.
364            let mut reachable_tables = BTreeMap::new();
365
366            for current_table in &current_table_set {
367                for edge in &solver_env.join_edge[current_table] {
368                    if !current_table_set.contains(&edge.right) {
369                        reachable_tables.insert(edge.right, edge.clone());
370                    }
371                }
372            }
373
374            // step 4: place arrangements according to the arrange strategy, update current
375            // distribution and current table set.
376            for table in &solver_env.arrange_placement_order {
377                if let Some(edge) = reachable_tables.get(table) {
378                    current_table_set.insert(edge.right);
379                    path.push(edge.right);
380                    current_distribution.clear();
381                    current_distribution.push((edge.right, edge.right_join_key.clone()));
382                    continue 'next_table;
383                }
384            }
385
386            // step 5: no tables can be joined any more, what happened?
387            if self.join_order.len() - 1 == path.len() {
388                break;
389            } else {
390                // no table can be joined, while path is still incomplete
391                return Err(anyhow!(
392                    "join plan cannot be generated, tables not connected."
393                ));
394            }
395        }
396
397        Ok(LookupPath(input_stream, path))
398    }
399
400    pub fn solve(&self) -> Result<Vec<LookupPath>> {
401        let solver_env = SolverEnv::build_from(self);
402
403        solver_env
404            .stream_placement_order
405            .iter()
406            .map(|x| self.find_lookup_path(&solver_env, *x))
407            .collect()
408    }
409
410    pub fn new(
411        stream_strategy: StreamStrategy,
412        edges: Vec<JoinEdge>,
413        join_order: Vec<JoinTable>,
414    ) -> Self {
415        Self {
416            stream_strategy,
417            edges,
418            join_order,
419            arrange_strategy: ArrangeStrategy::LeftFirst,
420        }
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a [`JoinEdge`] between tables `left` and `right` on the given join keys.
    fn edge(left: usize, left_key: &[usize], right: usize, right_key: &[usize]) -> JoinEdge {
        JoinEdge {
            left: JoinTable(left),
            right: JoinTable(right),
            left_join_key: left_key.to_vec(),
            right_join_key: right_key.to_vec(),
        }
    }

    /// Wraps raw table ids into [`JoinTable`]s.
    fn tables(ids: &[usize]) -> Vec<JoinTable> {
        ids.iter().copied().map(JoinTable).collect()
    }

    /// Shorthand for the expected row `stream -> lookups[0] -> lookups[1] -> ...`.
    fn row(stream: usize, lookups: &[usize]) -> LookupPath {
        LookupPath(JoinTable(stream), tables(lookups))
    }

    #[test]
    fn test_2way_join() {
        // The same two-table graph is solved under every combination of strategies.
        let solve_with = |arrange_strategy: ArrangeStrategy, stream_strategy: StreamStrategy| {
            DeltaJoinSolver {
                arrange_strategy,
                stream_strategy,
                edges: vec![edge(1, &[2, 3], 2, &[3, 2])],
                join_order: tables(&[1, 2]),
            }
            .solve()
            .unwrap()
        };

        assert_eq!(
            solve_with(ArrangeStrategy::LeftFirst, StreamStrategy::LeftThisEpoch),
            vec![row(1, &[2]), row(2, &[1])]
        );

        assert_eq!(
            solve_with(ArrangeStrategy::RightFirst, StreamStrategy::LeftThisEpoch),
            vec![row(1, &[2]), row(2, &[1])]
        );

        assert_eq!(
            solve_with(ArrangeStrategy::LeftFirst, StreamStrategy::RightThisEpoch),
            vec![row(2, &[1]), row(1, &[2])]
        );

        assert_eq!(
            solve_with(ArrangeStrategy::RightFirst, StreamStrategy::RightThisEpoch),
            vec![row(2, &[1]), row(1, &[2])]
        );
    }

    #[test]
    fn test_3way_join_one_key() {
        // Table 1: [2, 3] (composite key x)
        // Table 2: [3, 2] (composite key x)
        // Table 3: [1, 1] (composite key x)
        // t1.x == t2.x == t3.x

        let solver = DeltaJoinSolver {
            arrange_strategy: ArrangeStrategy::LeftFirst,
            stream_strategy: StreamStrategy::LeftThisEpoch,
            edges: vec![
                edge(1, &[2, 3], 2, &[3, 2]),
                edge(2, &[3, 2], 3, &[1, 1]),
                edge(3, &[1, 1], 1, &[2, 3]),
            ],
            join_order: tables(&[1, 2, 3]),
        };

        let result = solver.solve().unwrap();

        assert_eq!(
            result,
            vec![row(1, &[2, 3]), row(2, &[1, 3]), row(3, &[1, 2])]
        );
    }

    #[test]
    fn test_3way_join_two_keys() {
        // Table 1: [0] (key x)
        // Table 2: [0] (key x), [1] (key y)
        // Table 3: [1] (key y)
        // t1.x == t2.x and t2.y == t3.y
        //
        // t1 cannot directly join with t3 in this case, as they don't share the same key

        let solver = DeltaJoinSolver {
            arrange_strategy: ArrangeStrategy::LeftFirst,
            stream_strategy: StreamStrategy::LeftThisEpoch,
            edges: vec![edge(1, &[0], 2, &[0]), edge(2, &[1], 3, &[1])],
            join_order: tables(&[1, 2, 3]),
        };

        let result = solver.solve().unwrap();

        assert_eq!(
            result,
            vec![row(1, &[2, 3]), row(2, &[1, 3]), row(3, &[2, 1])]
        );
    }

    #[test]
    fn test_invalid_plan() {
        // Table 3 is part of the join order but not connected to any other table.
        let solver = DeltaJoinSolver {
            arrange_strategy: ArrangeStrategy::LeftFirst,
            stream_strategy: StreamStrategy::LeftThisEpoch,
            edges: vec![edge(1, &[0], 2, &[0])],
            join_order: tables(&[1, 2, 3]),
        };

        let result = solver.solve();

        assert!(result.is_err());
    }
}