risingwave_meta/stream/stream_graph/state_match.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! This module contains the logic for matching the internal state tables of two streaming jobs,
//! used for replacing a streaming job (typically `ALTER MV`) while preserving the existing state.
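//!
//! At a high level, matching proceeds in two phases:
//!
//! 1. `match_graph` pairs up the fragments of the two graphs. Candidates are pre-filtered by a
//!    structural fingerprint (`Graph::fingerprints`) and then resolved with a backtracking
//!    search, since several fragments may share the same fingerprint.
//! 2. For each candidate pair, `Matches::try_match` walks the operator trees of both fragments
//!    in lockstep and pairs their internal state tables by name, requiring the "physical" parts
//!    of the table catalogs (as captured by `TableDesc`) to be identical.
//!
//! A rough sketch of how the entry points fit together (the arguments here are hypothetical
//! placeholders; their construction is elided):
//!
//! ```ignore
//! let g1 = Graph::from_existing(&stream_job_fragments, &fragment_upstreams);
//! let g2 = Graph::from_building(&stream_fragment_graph);
//! // Maps each internal table id of `g1` to the id of its matched table in `g2`.
//! let table_mapping = match_graph_internal_tables(&g1, &g2)?;
//! ```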

use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::{DefaultHasher, Hash as _, Hasher as _};

use itertools::Itertools;
use risingwave_common::catalog::TableDesc;
use risingwave_common::util::stream_graph_visitor::visit_stream_node_tables_inner;
use risingwave_pb::catalog::PbTable;
use risingwave_pb::stream_plan::StreamNode;
use strum::IntoDiscriminant;

use crate::model::StreamJobFragments;
use crate::stream::StreamFragmentGraph;

/// Helper type for describing a [`StreamNode`] in error messages.
pub(crate) struct StreamNodeDesc(Box<str>);

impl From<&StreamNode> for StreamNodeDesc {
    fn from(node: &StreamNode) -> Self {
        let id = node.operator_id;
        let identity = &node.identity;
        let body = node.node_body.as_ref().unwrap();

        Self(format!("{}({}, {})", body, id, identity).into_boxed_str())
    }
}

impl std::fmt::Display for StreamNodeDesc {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Error type for failed state table matching.
#[derive(thiserror::Error, thiserror_ext::Macro, thiserror_ext::ReportDebug)]
pub(crate) enum Error {
    #[error("failed to match graph: {message}")]
    Graph { message: String },

    #[error("failed to match fragment {id}: {message}")]
    Fragment {
        source: Option<Box<Error>>,
        id: Id,
        message: String,
    },

    #[error("failed to match operator {from} to {to}: {message}")]
    Operator {
        from: StreamNodeDesc,
        to: StreamNodeDesc,
        message: String,
    },
}

type Result<T, E = Error> = std::result::Result<T, E>;

/// Fragment id.
type Id = u32;

/// Node for a fragment in the [`Graph`].
struct Fragment {
    /// The fragment id.
    id: Id,
    /// The root node of the fragment.
    root: StreamNode,
}

/// A streaming job graph that's used for state table matching.
pub struct Graph {
    /// All fragments in the graph.
    nodes: HashMap<Id, Fragment>,
    /// Downstreams of each fragment.
    downstreams: HashMap<Id, Vec<Id>>,
    /// Upstreams of each fragment.
    upstreams: HashMap<Id, Vec<Id>>,
}

impl Graph {
    /// Returns the number of fragments in the graph.
    fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the downstreams of a fragment.
    fn downstreams(&self, id: Id) -> &[Id] {
        self.downstreams.get(&id).map_or(&[], |v| v.as_slice())
    }

    /// Returns the upstreams of a fragment.
    fn upstreams(&self, id: Id) -> &[Id] {
        self.upstreams.get(&id).map_or(&[], |v| v.as_slice())
    }

    /// Returns the topological order of the graph.
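    ///
    /// The order starts from fragments with no downstreams and proceeds towards upstream
    /// fragments, so iterating it in reverse (as `fingerprints` does) guarantees that all
    /// upstreams of a fragment are visited before the fragment itself.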
    fn topo_order(&self) -> Result<Vec<Id>> {
        let mut topo = Vec::new();
        let mut downstream_cnts = HashMap::new();

        // Iterate all nodes to find the roots (fragments with no downstreams) and initialize
        // the downstream counts.
        for node_id in self.nodes.keys() {
            let downstream_cnt = self.downstreams(*node_id).len();
            if downstream_cnt == 0 {
                topo.push(*node_id);
            } else {
                downstream_cnts.insert(*node_id, downstream_cnt);
            }
        }

        let mut i = 0;
        while let Some(&node_id) = topo.get(i) {
            i += 1;
            // Check whether more nodes become ready to process.
            for &upstream_id in self.upstreams(node_id) {
                let downstream_cnt = downstream_cnts.get_mut(&upstream_id).unwrap();
                *downstream_cnt -= 1;
                if *downstream_cnt == 0 {
                    downstream_cnts.remove(&upstream_id);
                    topo.push(upstream_id);
                }
            }
        }

        if !downstream_cnts.is_empty() {
            // Some nodes were never processed, which indicates a cycle.
            bail_graph!("fragment graph is not a DAG");
        }
        assert_eq!(topo.len(), self.len());

        Ok(topo)
    }

    /// Calculates the fingerprints of all fragments based on their position in the graph.
    ///
    /// This is used to locate the candidates when matching a fragment against another graph.
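    ///
    /// The fingerprint is purely structural: it hashes only the upstream count, the downstream
    /// count, and the (sorted) fingerprints of the upstreams, never the operators inside the
    /// fragment. Structurally symmetric fragments can therefore share a fingerprint, which is
    /// why `match_graph` still needs a backtracking search over the candidates.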
    fn fingerprints(&self) -> Result<HashMap<Id, u64>> {
        let mut fps = HashMap::new();

        let order = self.topo_order()?;
        for u in order.into_iter().rev() {
            let upstream_fps = self
                .upstreams(u)
                .iter()
                .map(|id| *fps.get(id).unwrap())
                .sorted() // allow input order to be arbitrary
                .collect_vec();

            // Hash the upstream count, downstream count, and upstream fingerprints to
            // generate the fingerprint of this node.
            let mut hasher = DefaultHasher::new();
            (
                self.upstreams(u).len(),
                self.downstreams(u).len(),
                upstream_fps,
            )
                .hash(&mut hasher);
            let fingerprint = hasher.finish();

            fps.insert(u, fingerprint);
        }

        Ok(fps)
    }
}

/// The match result of a fragment in the source graph to a fragment in the target graph.
struct Match {
    /// The target fragment id.
    target: Id,
    /// The mapping from source table id to target table id within the fragment.
    table_matches: HashMap<u32, u32>,
}

/// The successful matching result of two [`Graph`]s.
struct Matches {
    /// The mapping from source fragment id to target fragment id.
    inner: HashMap<Id, Match>,
    /// The set of target fragment ids that are already matched.
    matched_targets: HashSet<Id>,
}

impl Matches {
    /// Creates a new empty match result.
    fn new() -> Self {
        Self {
            inner: HashMap::new(),
            matched_targets: HashSet::new(),
        }
    }

    /// Returns the target fragment id of a source fragment id.
    fn target(&self, u: Id) -> Option<Id> {
        self.inner.get(&u).map(|m| m.target)
    }

    /// Returns the number of matched fragments.
    fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns true if the source fragment is already matched.
    fn matched(&self, u: Id) -> bool {
        self.inner.contains_key(&u)
    }

    /// Returns true if the target fragment is already matched.
    fn target_matched(&self, v: Id) -> bool {
        self.matched_targets.contains(&v)
    }

    /// Tries to match a source fragment to a target fragment. If successful, they will be recorded
    /// in the match result.
    ///
    /// This will check the operators and internal tables of the fragment.
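    ///
    /// The operator trees of the two fragments are traversed in lockstep with a BFS, pairing
    /// nodes positionally. Within each pair of nodes, internal tables are paired by their names
    /// and must have identical physical properties (see the `TableDesc` comparison below).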
    fn try_match(&mut self, u: &Fragment, v: &Fragment) -> Result<()> {
        if self.matched(u.id) {
            panic!("fragment {} was already matched", u.id);
        }

        // Collect the internal tables of a node (not visiting children).
        let collect_tables = |x: &StreamNode| {
            let mut tables = Vec::new();
            visit_stream_node_tables_inner(&mut x.clone(), true, false, |table, name| {
                tables.push((name.to_owned(), table.clone()));
            });
            tables
        };

        let mut table_matches = HashMap::new();

        // Use BFS to match the operator nodes.
        let mut uq = VecDeque::from([&u.root]);
        let mut vq = VecDeque::from([&v.root]);

        while let Some(mut un) = uq.pop_front() {
            // Since we ensure the number of inputs of an operator is the same before extending
            // the BFS queue, we can safely unwrap here.
            let mut vn = vq.pop_front().unwrap();

            // Skip over stateless nodes that have exactly one input.
            let mut u_tables = collect_tables(un);
            while u_tables.is_empty() && un.input.len() == 1 {
                un = &un.input[0];
                u_tables = collect_tables(un);
            }
            let mut v_tables = collect_tables(vn);
            while v_tables.is_empty() && vn.input.len() == 1 {
                vn = &vn.input[0];
                v_tables = collect_tables(vn);
            }
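            // Skipping stateless pass-through nodes on both sides keeps the two cursors aligned
            // even if one plan contains an extra stateless operator (e.g. a projection) that the
            // other does not.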

            // If we reach the leaf nodes, we are done with this fragment.
            if un.input.is_empty() && vn.input.is_empty() {
                continue;
            }

            // Perform checks.
            if un.node_body.as_ref().unwrap().discriminant()
                != vn.node_body.as_ref().unwrap().discriminant()
            {
                bail_operator!(from = un, to = vn, "operator has different type");
            }
            if un.input.len() != vn.input.len() {
                bail_operator!(
                    from = un,
                    to = vn,
                    "operator has different number of inputs ({} vs {})",
                    un.input.len(),
                    vn.input.len()
                );
            }

            // Extend the BFS queue.
            uq.extend(un.input.iter());
            vq.extend(vn.input.iter());

            for (ut_name, ut) in u_tables {
                let vt_cands = v_tables
                    .extract_if(.., |(vt_name, _)| *vt_name == ut_name)
                    .collect_vec();

                if vt_cands.is_empty() {
                    bail_operator!(
                        from = un,
                        to = vn,
                        "cannot find a match for table `{ut_name}`",
                    );
                } else if vt_cands.len() > 1 {
                    bail_operator!(
                        from = un,
                        to = vn,
                        "found multiple matches for table `{ut_name}`",
                    );
                }

                let (_, vt) = vt_cands.into_iter().next().unwrap();

                // Since the requirement is to ensure state compatibility, we focus solely on the
                // "physical" part of the table, best captured by `TableDesc`.
                let table_desc_for_compare = |table: &PbTable| {
                    let mut table = table.clone();
                    table.id = 0; // ignore id
                    // The vnode count is unfilled for the new fragment graph, so fill it with a
                    // dummy value before proceeding.
                    table.maybe_vnode_count = Some(42);

                    TableDesc::from_pb_table(&table)
                };

                let ut_compare = table_desc_for_compare(&ut);
                let vt_compare = table_desc_for_compare(&vt);

                if ut_compare != vt_compare {
                    bail_operator!(
                        from = un,
                        to = vn,
                        "found a match for table `{ut_name}`, but they are incompatible, diff:\n{}",
                        pretty_assertions::Comparison::new(&ut_compare, &vt_compare)
                    );
                }

                table_matches.try_insert(ut.id, vt.id).unwrap_or_else(|_| {
                    panic!("duplicated table id {} in fragment {}", ut.id, u.id)
                });
            }
        }

        let m = Match {
            target: v.id,
            table_matches,
        };
        self.inner.insert(u.id, m);
        self.matched_targets.insert(v.id);

        Ok(())
    }

    /// Undoes the match of a source fragment.
    fn undo_match(&mut self, u: Id) {
        let target = self
            .inner
            .remove(&u)
            .unwrap_or_else(|| panic!("fragment {} was not previously matched", u))
            .target;

        let target_removed = self.matched_targets.remove(&target);
        assert!(target_removed);
    }

    /// Converts the match result into a table mapping.
    fn into_table_mapping(self) -> HashMap<u32, u32> {
        self.inner
            .into_iter()
            .flat_map(|(_, m)| m.table_matches.into_iter())
            .collect()
    }
}

/// Matches two [`Graph`]s, and returns the match result from each fragment in `g1` to `g2`.
fn match_graph(g1: &Graph, g2: &Graph) -> Result<Matches> {
    if g1.len() != g2.len() {
        bail_graph!(
            "graphs have different number of fragments ({} vs {})",
            g1.len(),
            g2.len()
        );
    }

    let fps1 = g1.fingerprints()?;
    let fps2 = g2.fingerprints()?;

    // Collect the candidates for each fragment.
    let mut fp_cand = HashMap::with_capacity(g1.len());
    for (&u, &f1) in &fps1 {
        for (&v, &f2) in &fps2 {
            if f1 == f2 {
                fp_cand.entry(u).or_insert_with(HashSet::new).insert(v);
            }
        }
    }

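    // Backtracking search over the candidates: pick the unmatched fragment with the fewest
    // remaining candidates, try each candidate with `Matches::try_match`, and recurse. On
    // failure, restore the candidate sets and undo the tentative match before trying the next
    // candidate.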
    fn dfs(
        g1: &Graph,
        g2: &Graph,
        fp_cand: &mut HashMap<Id, HashSet<Id>>,
        matches: &mut Matches,
    ) -> Result<()> {
        // If all fragments are matched, return.
        if matches.len() == g1.len() {
            return Ok(());
        }

        // Choose the unmatched fragment with the fewest remaining candidates.
        let (&u, u_cands) = fp_cand
            .iter()
            .filter(|(u, _)| !matches.matched(**u))
            .min_by_key(|(_, cands)| cands.len())
            .unwrap();
        let u_cands = u_cands.clone();

        let mut last_error = None;

        'candidates: for &v in &u_cands {
            // Skip if v is already used.
            if matches.target_matched(v) {
                continue;
            }

            // For each upstream of u that's already matched, it must be matched to the
            // corresponding upstream of v.
            let upstreams = g1.upstreams(u).to_vec();
            for u_upstream in upstreams {
                if let Some(v_upstream) = matches.target(u_upstream)
                    && !g2.upstreams(v).contains(&v_upstream)
                {
                    // Not a valid match, try the next candidate.
                    continue 'candidates;
                }
            }
            // Same for the downstreams of u.
            let downstreams = g1.downstreams(u).to_vec();
            for u_downstream in downstreams {
                if let Some(v_downstream) = matches.target(u_downstream)
                    && !g2.downstreams(v).contains(&v_downstream)
                {
                    // Not a valid match, try the next candidate.
                    continue 'candidates;
                }
            }

            // Now that `u` and `v` are at the same position in the graph, try to match them by
            // visiting the operators.
            match matches.try_match(&g1.nodes[&u], &g2.nodes[&v]) {
                Ok(()) => {
                    let fp_cand_clone = fp_cand.clone();

                    // v cannot be a candidate for any other u. Remove it before proceeding.
                    for (_, u_cands) in fp_cand.iter_mut() {
                        u_cands.remove(&v);
                    }

                    // Try to match the rest.
                    match dfs(g1, g2, fp_cand, matches) {
                        Ok(()) => return Ok(()), // success, return
                        Err(err) => {
                            last_error = Some(err);

                            // Backtrack.
                            *fp_cand = fp_cand_clone;
                            matches.undo_match(u);
                        }
                    }
                }

                Err(err) => last_error = Some(err),
            }
        }

        if let Some(error) = last_error {
            bail_fragment!(
                source = Box::new(error),
                id = u,
                "tried against all {} candidates, but failed",
                u_cands.len()
            );
        } else {
            bail_fragment!(
                id = u,
                "cannot find a candidate with the same topological position"
            )
        }
    }

    let mut matches = Matches::new();
    dfs(g1, g2, &mut fp_cand, &mut matches)?;
    Ok(matches)
}

/// Matches two [`Graph`]s, and returns the internal table mapping from `g1` to `g2`.
pub(crate) fn match_graph_internal_tables(g1: &Graph, g2: &Graph) -> Result<HashMap<u32, u32>> {
    match_graph(g1, g2).map(|matches| matches.into_table_mapping())
}

impl Graph {
    /// Creates a [`Graph`] from a [`StreamFragmentGraph`] that's being built.
    pub(crate) fn from_building(graph: &StreamFragmentGraph) -> Self {
        let nodes = graph
            .fragments
            .iter()
            .map(|(&id, f)| {
                (
                    id.as_global_id(),
                    Fragment {
                        id: id.as_global_id(),
                        root: f.node.clone().unwrap(),
                    },
                )
            })
            .collect();

        let downstreams = graph
            .downstreams
            .iter()
            .map(|(&id, downstreams)| {
                (
                    id.as_global_id(),
                    downstreams
                        .iter()
                        .map(|(&id, _)| id.as_global_id())
                        .collect(),
                )
            })
            .collect();

        let upstreams = graph
            .upstreams
            .iter()
            .map(|(&id, upstreams)| {
                (
                    id.as_global_id(),
                    upstreams.iter().map(|(&id, _)| id.as_global_id()).collect(),
                )
            })
            .collect();

        Self {
            nodes,
            downstreams,
            upstreams,
        }
    }

    /// Creates a [`Graph`] from an existing [`StreamJobFragments`].
    pub(crate) fn from_existing(
        fragments: &StreamJobFragments,
        fragment_upstreams: &HashMap<u32, HashSet<u32>>,
    ) -> Self {
        let nodes: HashMap<_, _> = fragments
            .fragments
            .iter()
            .map(|(&id, f)| {
                (
                    id,
                    Fragment {
                        id,
                        root: f.nodes.clone(),
                    },
                )
            })
            .collect();

        let mut downstreams = HashMap::new();
        let mut upstreams = HashMap::new();

        for (&id, fragment_upstreams) in fragment_upstreams {
            assert!(nodes.contains_key(&id));

            for &upstream in fragment_upstreams {
                if !nodes.contains_key(&upstream) {
                    // The upstream fragment can belong to a different job; ignore it.
                    continue;
                }
                downstreams
                    .entry(upstream)
                    .or_insert_with(Vec::new)
                    .push(id);
                upstreams.entry(id).or_insert_with(Vec::new).push(upstream);
            }
        }

        Self {
            nodes,
            downstreams,
            upstreams,
        }
    }
}