risingwave_meta/stream/stream_graph/state_match.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! This module contains the logic for matching the internal state tables of two streaming jobs,
//! used for replacing a streaming job (typically `ALTER MV`) while preserving the existing state.
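//!
//! At a high level, matching works in three steps: fragments of both graphs are fingerprinted
//! by their topological position, a backtracking search assigns each source fragment to a
//! target fragment among the candidates with an equal fingerprint, and every candidate pair is
//! verified by walking the operator trees of the two fragments in lockstep, matching their
//! internal state tables by name and checking that the table schemas are compatible.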

use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::{DefaultHasher, Hash as _, Hasher as _};

use itertools::Itertools;
use risingwave_common::catalog::{TableDesc, TableId};
use risingwave_common::id::FragmentId;
use risingwave_common::util::stream_graph_visitor::visit_stream_node_tables_inner;
use risingwave_pb::catalog::PbTable;
use risingwave_pb::id::StreamNodeLocalOperatorId;
use risingwave_pb::stream_plan::stream_node::PbNodeBody;
use risingwave_pb::stream_plan::{PbStreamScanType, StreamNode};
use strum::IntoDiscriminant;

use crate::model::StreamJobFragments;
use crate::stream::StreamFragmentGraph;

/// Helper type for describing a [`StreamNode`] in error messages.
pub(crate) struct StreamNodeDesc(Box<str>);

impl From<&StreamNode> for StreamNodeDesc {
    fn from(node: &StreamNode) -> Self {
        let id = node.operator_id;
        let identity = &node.identity;
        let body = node.node_body.as_ref().unwrap();

        Self(format!("{}({}, {})", body, id, identity).into_boxed_str())
    }
}

impl std::fmt::Display for StreamNodeDesc {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Error type for failed state table matching.
#[derive(thiserror::Error, thiserror_ext::Macro, thiserror_ext::ReportDebug)]
pub(crate) enum Error {
    #[error("failed to match graph: {message}")]
    Graph { message: String },

    #[error("failed to match fragment {id}: {message}")]
    Fragment {
        source: Option<Box<Error>>,
        id: Id,
        message: String,
    },

    #[error("failed to match operator {from} to {to}: {message}")]
    Operator {
        from: StreamNodeDesc,
        to: StreamNodeDesc,
        message: String,
    },
}

type Result<T, E = Error> = std::result::Result<T, E>;

/// Fragment id.
type Id = FragmentId;

/// Node for a fragment in the [`Graph`].
struct Fragment {
    /// The fragment id.
    id: Id,
    /// The root node of the fragment.
    root: StreamNode,
}

/// A streaming job graph that's used for state table matching.
pub struct Graph {
    /// All fragments in the graph.
    nodes: HashMap<Id, Fragment>,
    /// Downstreams of each fragment.
    downstreams: HashMap<Id, Vec<Id>>,
    /// Upstreams of each fragment.
    upstreams: HashMap<Id, Vec<Id>>,
}

impl Graph {
    /// Returns the number of fragments in the graph.
    fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the downstreams of a fragment.
    fn downstreams(&self, id: Id) -> &[Id] {
        self.downstreams.get(&id).map_or(&[], |v| v.as_slice())
    }

    /// Returns the upstreams of a fragment.
    fn upstreams(&self, id: Id) -> &[Id] {
        self.upstreams.get(&id).map_or(&[], |v| v.as_slice())
    }

    /// Returns the topological order of the graph.
    fn topo_order(&self) -> Result<Vec<Id>> {
        let mut topo = Vec::new();
        let mut downstream_cnts = HashMap::new();

        // Iterate all nodes to find the root and initialize the downstream counts.
        for node_id in self.nodes.keys() {
            let downstream_cnt = self.downstreams(*node_id).len();
            if downstream_cnt == 0 {
                topo.push(*node_id);
            } else {
                downstream_cnts.insert(*node_id, downstream_cnt);
            }
        }

        let mut i = 0;
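        // Process the queue in BFS order (Kahn's algorithm): a fragment is appended once all of
        // its downstreams have been appended, so the resulting order runs from the root
        // fragments (those without downstreams) towards the source fragments.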
        while let Some(&node_id) = topo.get(i) {
            i += 1;
            // Find if we can process more nodes.
            for &upstream_id in self.upstreams(node_id) {
                let downstream_cnt = downstream_cnts.get_mut(&upstream_id).unwrap();
                *downstream_cnt -= 1;
                if *downstream_cnt == 0 {
                    downstream_cnts.remove(&upstream_id);
                    topo.push(upstream_id);
                }
            }
        }

        if !downstream_cnts.is_empty() {
            // There are nodes that are not processed yet.
            bail_graph!("fragment graph is not a DAG");
        }
        assert_eq!(topo.len(), self.len());

        Ok(topo)
    }

    /// Calculates the fingerprints of all fragments based on their position in the graph.
    ///
    /// This is used to locate the candidates when matching a fragment against another graph.
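    ///
    /// Note that a fingerprint only encodes graph topology (the fragment's in/out degree and the
    /// fingerprints of its upstreams), not the operators inside the fragment. Equal fingerprints
    /// are therefore necessary but not sufficient for a match; the operators and tables are
    /// verified later in [`Matches::try_match`].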
    fn fingerprints(&self) -> Result<HashMap<Id, u64>> {
        let mut fps = HashMap::new();

        let order = self.topo_order()?;
        for u in order.into_iter().rev() {
            let upstream_fps = self
                .upstreams(u)
                .iter()
                .map(|id| *fps.get(id).unwrap())
                .sorted() // allow input order to be arbitrary
                .collect_vec();

            // Hash the upstream count, downstream count, and upstream fingerprints to
            // generate the fingerprint of this node.
            let mut hasher = DefaultHasher::new();
            (
                self.upstreams(u).len(),
                self.downstreams(u).len(),
                upstream_fps,
            )
                .hash(&mut hasher);
            let fingerprint = hasher.finish();

            fps.insert(u, fingerprint);
        }

        Ok(fps)
    }
}

#[derive(Default)]
pub(crate) struct MatchResult {
    /// The mapping from source table id to the matched target table.
    pub table_matches: HashMap<TableId, PbTable>,
    /// The mapping from source `operator_id` to the snapshot epoch of the matched target scan.
    pub snapshot_backfill_epochs: HashMap<StreamNodeLocalOperatorId, u64>,
}

/// The match result of a fragment in the source graph to a fragment in the target graph.
struct Match {
    /// The target fragment id.
    target: Id,
    result: MatchResult,
}

/// The successful matching result of two [`Graph`]s.
struct Matches {
    /// The mapping from source fragment id to target fragment id.
    inner: HashMap<Id, Match>,
    /// The set of target fragment ids that are already matched.
    matched_targets: HashSet<Id>,
}

impl Matches {
    /// Creates a new empty match result.
    fn new() -> Self {
        Self {
            inner: HashMap::new(),
            matched_targets: HashSet::new(),
        }
    }

    /// Returns the target fragment id of a source fragment id.
    fn target(&self, u: Id) -> Option<Id> {
        self.inner.get(&u).map(|m| m.target)
    }

    /// Returns the number of matched fragments.
    fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns true if the source fragment is already matched.
    fn matched(&self, u: Id) -> bool {
        self.inner.contains_key(&u)
    }

    /// Returns true if the target fragment is already matched.
    fn target_matched(&self, v: Id) -> bool {
        self.matched_targets.contains(&v)
    }

    /// Tries to match a source fragment to a target fragment. If successful, they will be recorded
    /// in the match result.
    ///
    /// This will check the operators and internal tables of the fragment.
    fn try_match(&mut self, u: &Fragment, v: &Fragment) -> Result<()> {
        if self.matched(u.id) {
            panic!("fragment {} was already matched", u.id);
        }

        // Collect the internal tables of a node (not visiting children).
        let collect_tables = |x: &StreamNode| {
            let mut tables = Vec::new();
            visit_stream_node_tables_inner(&mut x.clone(), true, false, |table, name| {
                tables.push((name.to_owned(), table.clone()));
            });
            tables
        };

        let mut result = MatchResult::default();

        // Use BFS to match the operator nodes.
        let mut uq = VecDeque::from([&u.root]);
        let mut vq = VecDeque::from([&v.root]);

        while let Some(mut un) = uq.pop_front() {
            // Since we ensure the number of inputs of an operator is the same before extending
            // the BFS queue, we can safely unwrap here.
            let mut vn = vq.pop_front().unwrap();

            // Skip while the node is stateless and has only one input.
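            // This allows the two plans to differ in chains of stateless, single-input
            // operators (e.g., an added or removed projection) without breaking the alignment
            // of the traversal.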
            let mut u_tables = collect_tables(un);
            while u_tables.is_empty() && un.input.len() == 1 {
                un = &un.input[0];
                u_tables = collect_tables(un);
            }
            let mut v_tables = collect_tables(vn);
            while v_tables.is_empty() && vn.input.len() == 1 {
                vn = &vn.input[0];
                v_tables = collect_tables(vn);
            }

            // If we have reached the leaf nodes, we are done with this branch of the fragment.
            if un.input.is_empty() && vn.input.is_empty() {
                continue;
            }

            // Perform checks.
            if un.node_body.as_ref().unwrap().discriminant()
                != vn.node_body.as_ref().unwrap().discriminant()
            {
                bail_operator!(from = un, to = vn, "operator has different type");
            }
            if let PbNodeBody::StreamScan(uscan) = un.node_body.as_ref().unwrap() {
                let PbNodeBody::StreamScan(vscan) = vn.node_body.as_ref().unwrap() else {
                    unreachable!("checked same discriminant");
                };
                if let scan_type @ (PbStreamScanType::SnapshotBackfill
                | PbStreamScanType::CrossDbSnapshotBackfill) = uscan.stream_scan_type()
                {
                    let Some(snapshot_epoch) = vscan.snapshot_backfill_epoch else {
                        bail_operator!(
                            from = un,
                            to = vn,
                            "expect snapshot_backfill_epoch set for new stream_scan_type {:?} with old stream_scan_type {:?}",
                            scan_type,
                            vscan.stream_scan_type()
                        );
                    };
                    result
                        .snapshot_backfill_epochs
                        .try_insert(un.operator_id, snapshot_epoch)
                        .unwrap();
                }
            }
            if un.input.len() != vn.input.len() {
                bail_operator!(
                    from = un,
                    to = vn,
                    "operator has different number of inputs ({} vs {})",
                    un.input.len(),
                    vn.input.len()
                );
            }

            // Extend the BFS queue.
            uq.extend(un.input.iter());
            vq.extend(vn.input.iter());

            for (ut_name, ut) in u_tables {
                let vt_cands = v_tables
                    .extract_if(.., |(vt_name, _)| *vt_name == ut_name)
                    .collect_vec();

                if vt_cands.is_empty() {
                    bail_operator!(
                        from = un,
                        to = vn,
                        "cannot find a match for table `{ut_name}`",
                    );
                } else if vt_cands.len() > 1 {
                    bail_operator!(
                        from = un,
                        to = vn,
                        "found multiple matches for table `{ut_name}`",
                    );
                }

                let (_, vt) = vt_cands.into_iter().next().unwrap();

                // Since the requirement is to ensure state compatibility, we focus solely on
                // the "physical" part of the table, which is captured by `TableDesc`.
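                // This covers the columns, primary key, and distribution of the table; the
                // table id is explicitly ignored below, and the vnode count is normalized
                // since it is not yet filled for the new fragment graph.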
                let table_desc_for_compare = |table: &PbTable| {
                    let mut table = table.clone();
                    table.id = 0.into(); // ignore id
                    table.maybe_vnode_count = Some(42); // vnode count is unfilled for new fragment graph, fill it with a dummy value before proceeding

                    TableDesc::from_pb_table(&table)
                };

                let ut_compare = table_desc_for_compare(&ut);
                let vt_compare = table_desc_for_compare(&vt);

                if ut_compare != vt_compare {
                    bail_operator!(
                        from = un,
                        to = vn,
                        "found a match for table `{ut_name}`, but they are incompatible, diff:\n{}",
                        pretty_assertions::Comparison::new(&ut_compare, &vt_compare)
                    );
                }

                result
                    .table_matches
                    .try_insert(ut.id, vt)
                    .unwrap_or_else(|_| {
                        panic!("duplicated table id {} in fragment {}", ut.id, u.id)
                    });
            }
        }

        let m = Match {
            target: v.id,
            result,
        };
        self.inner.insert(u.id, m);
        self.matched_targets.insert(v.id);

        Ok(())
    }

    /// Undoes the match of a source fragment.
    fn undo_match(&mut self, u: Id) {
        let target = self
            .inner
            .remove(&u)
            .unwrap_or_else(|| panic!("fragment {} was not previously matched", u))
            .target;

        let target_removed = self.matched_targets.remove(&target);
        assert!(target_removed);
    }

    /// Merges the per-fragment match results into a single [`MatchResult`] for the whole graph.
    fn into_match_result(self) -> MatchResult {
        let mut result = MatchResult::default();
        for matches in self.inner.into_values() {
            for (table_id, table) in matches.result.table_matches {
                result
                    .table_matches
                    .try_insert(table_id, table)
                    .expect("non-duplicated");
            }
            for (operator_id, epoch) in matches.result.snapshot_backfill_epochs {
                result
                    .snapshot_backfill_epochs
                    .try_insert(operator_id, epoch)
                    .expect("non-duplicated");
            }
        }
        result
    }
}

/// Matches two [`Graph`]s and returns the merged [`MatchResult`] that maps state from `g1`
/// to `g2`.
pub(crate) fn match_graph(g1: &Graph, g2: &Graph) -> Result<MatchResult> {
    if g1.len() != g2.len() {
        bail_graph!(
            "graphs have different number of fragments ({} vs {})",
            g1.len(),
            g2.len()
        );
    }

    let fps1 = g1.fingerprints()?;
    let fps2 = g2.fingerprints()?;

    // Collect the candidates for each fragment.
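    // Only fragments with equal fingerprints, i.e., the same topological position, are
    // considered candidates for each other.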
    let mut fp_cand = HashMap::with_capacity(g1.len());
    for (&u, &f1) in &fps1 {
        for (&v, &f2) in &fps2 {
            if f1 == f2 {
                fp_cand.entry(u).or_insert_with(HashSet::new).insert(v);
            }
        }
    }

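    // Backtracking search: extend the partial matching one fragment at a time, restricting the
    // candidates by fingerprint and by the already-matched neighbors, and undo the attempt when
    // a branch turns out to be a dead end.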
    fn dfs(
        g1: &Graph,
        g2: &Graph,
        fp_cand: &mut HashMap<Id, HashSet<Id>>,
        matches: &mut Matches,
    ) -> Result<()> {
        // If all fragments are matched, return.
        if matches.len() == g1.len() {
            return Ok(());
        }

        // Choose the unmatched fragment with the fewest remaining candidates (the
        // "minimum remaining values" heuristic) to prune the search.
        let (&u, u_cands) = fp_cand
            .iter()
            .filter(|(u, _)| !matches.matched(**u))
            .min_by_key(|(_, cands)| cands.len())
            .unwrap();
        let u_cands = u_cands.clone();

        let mut last_error = None;

        'cand_v: for &v in &u_cands {
            // Skip if v is already used.
            if matches.target_matched(v) {
                continue;
            }

            // For each upstream of `u` that is already matched, its match must also be an
            // upstream of `v`.
            let upstreams = g1.upstreams(u).to_vec();
            for u_upstream in upstreams {
                if let Some(v_upstream) = matches.target(u_upstream)
                    && !g2.upstreams(v).contains(&v_upstream)
                {
                    // Not a valid match.
                    continue 'cand_v;
                }
            }
            // Same for downstream of u.
            let downstreams = g1.downstreams(u).to_vec();
            for u_downstream in downstreams {
                if let Some(v_downstream) = matches.target(u_downstream)
                    && !g2.downstreams(v).contains(&v_downstream)
                {
                    // Not a valid match.
                    continue 'cand_v;
                }
            }

            // Now that `u` and `v` are at the same position in the graph, try to match them by
            // visiting the operators.
            match matches.try_match(&g1.nodes[&u], &g2.nodes[&v]) {
                Ok(()) => {
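                    // Snapshot the candidate sets so they can be restored when backtracking.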
                    let fp_cand_clone = fp_cand.clone();

                    // v cannot be a candidate for any other u. Remove it before proceeding.
                    for (_, u_cands) in fp_cand.iter_mut() {
                        u_cands.remove(&v);
                    }

                    // Try to match the rest.
                    match dfs(g1, g2, fp_cand, matches) {
                        Ok(()) => return Ok(()), // success, return
                        Err(err) => {
                            last_error = Some(err);

                            // Backtrack.
                            *fp_cand = fp_cand_clone;
                            matches.undo_match(u);
                        }
                    }
                }

                Err(err) => last_error = Some(err),
            }
        }

        if let Some(error) = last_error {
            bail_fragment!(
                source = Box::new(error),
                id = u,
                "tried against all {} candidates, but failed",
                u_cands.len()
            );
        } else {
            bail_fragment!(
                id = u,
                "cannot find a candidate with the same topological position"
            )
        }
    }

    let mut matches = Matches::new();
    dfs(g1, g2, &mut fp_cand, &mut matches)?;
    Ok(matches.into_match_result())
}

impl Graph {
    /// Creates a [`Graph`] from a [`StreamFragmentGraph`] that's being built.
    pub(crate) fn from_building(graph: &StreamFragmentGraph) -> Self {
        let nodes = graph
            .fragments
            .iter()
            .map(|(&id, f)| {
                let id = id.as_global_id();
                (
                    id,
                    Fragment {
                        id,
                        root: f.node.clone().unwrap(),
                    },
                )
            })
            .collect();

        let downstreams = graph
            .downstreams
            .iter()
            .map(|(&id, downstreams)| {
                (
                    id.as_global_id(),
                    downstreams
                        .iter()
                        .map(|(&id, _)| id.as_global_id())
                        .collect(),
                )
            })
            .collect();

        let upstreams = graph
            .upstreams
            .iter()
            .map(|(&id, upstreams)| {
                (
                    id.as_global_id(),
                    upstreams.iter().map(|(&id, _)| id.as_global_id()).collect(),
                )
            })
            .collect();

        Self {
            nodes,
            downstreams,
            upstreams,
        }
    }

    /// Creates a [`Graph`] from an existing [`StreamJobFragments`].
    pub(crate) fn from_existing(
        fragments: &StreamJobFragments,
        fragment_upstreams: &HashMap<Id, HashSet<Id>>,
    ) -> Self {
        let nodes: HashMap<_, _> = fragments
            .fragments
            .iter()
            .map(|(&id, f)| {
                (
                    id,
                    Fragment {
                        id,
                        root: f.nodes.clone(),
                    },
                )
            })
            .collect();

        let mut downstreams = HashMap::new();
        let mut upstreams = HashMap::new();

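        // Derive both adjacency maps from `fragment_upstreams`, keeping only the edges whose
        // endpoints both belong to this job.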
        for (&id, fragment_upstreams) in fragment_upstreams {
            assert!(nodes.contains_key(&id));

            for &upstream in fragment_upstreams {
                if !nodes.contains_key(&upstream) {
                    // The upstream fragment may belong to a different job; ignore it.
                    continue;
                }
                downstreams
                    .entry(upstream)
                    .or_insert_with(Vec::new)
                    .push(id);
                upstreams.entry(id).or_insert_with(Vec::new).push(upstream);
            }
        }

        Self {
            nodes,
            downstreams,
            upstreams,
        }
    }
}