risingwave_meta/stream/stream_graph/state_match.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! This module contains the logic for matching the internal state tables of two streaming jobs,
//! used for replacing a streaming job (typically `ALTER MV`) while preserving the existing state.
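//!
//! At a high level, matching works in three steps: each fragment is assigned a fingerprint
//! derived from its topological position in the graph, fragments with equal fingerprints
//! become candidates for each other, and a backtracking search then fixes a one-to-one
//! mapping by walking the operator trees of candidate pairs and matching their internal
//! tables by name and schema.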

use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::{DefaultHasher, Hash as _, Hasher as _};

use itertools::Itertools;
use risingwave_common::catalog::{TableDesc, TableId};
use risingwave_common::id::FragmentId;
use risingwave_common::util::stream_graph_visitor::visit_stream_node_tables_inner;
use risingwave_pb::catalog::PbTable;
use risingwave_pb::stream_plan::StreamNode;
use strum::IntoDiscriminant;

use crate::model::StreamJobFragments;
use crate::stream::StreamFragmentGraph;

/// Helper type for describing a [`StreamNode`] in error messages.
pub(crate) struct StreamNodeDesc(Box<str>);

impl From<&StreamNode> for StreamNodeDesc {
    fn from(node: &StreamNode) -> Self {
        let id = node.operator_id;
        let identity = &node.identity;
        let body = node.node_body.as_ref().unwrap();

        Self(format!("{}({}, {})", body, id, identity).into_boxed_str())
    }
}

impl std::fmt::Display for StreamNodeDesc {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Error type for failed state table matching.
#[derive(thiserror::Error, thiserror_ext::Macro, thiserror_ext::ReportDebug)]
pub(crate) enum Error {
    #[error("failed to match graph: {message}")]
    Graph { message: String },

    #[error("failed to match fragment {id}: {message}")]
    Fragment {
        source: Option<Box<Error>>,
        id: Id,
        message: String,
    },

    #[error("failed to match operator {from} to {to}: {message}")]
    Operator {
        from: StreamNodeDesc,
        to: StreamNodeDesc,
        message: String,
    },
}

type Result<T, E = Error> = std::result::Result<T, E>;

/// Fragment id.
type Id = FragmentId;

/// Node for a fragment in the [`Graph`].
struct Fragment {
    /// The fragment id.
    id: Id,
    /// The root node of the fragment.
    root: StreamNode,
}

/// A streaming job graph that's used for state table matching.
pub struct Graph {
    /// All fragments in the graph.
    nodes: HashMap<Id, Fragment>,
    /// Downstreams of each fragment.
    downstreams: HashMap<Id, Vec<Id>>,
    /// Upstreams of each fragment.
    upstreams: HashMap<Id, Vec<Id>>,
}

impl Graph {
    /// Returns the number of fragments in the graph.
    fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the downstreams of a fragment.
    fn downstreams(&self, id: Id) -> &[Id] {
        self.downstreams.get(&id).map_or(&[], |v| v.as_slice())
    }

    /// Returns the upstreams of a fragment.
    fn upstreams(&self, id: Id) -> &[Id] {
        self.upstreams.get(&id).map_or(&[], |v| v.as_slice())
    }

    /// Returns the topological order of the graph.
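    ///
    /// The order is computed by repeatedly peeling off fragments whose downstreams have all been
    /// visited (Kahn's algorithm on the reversed edges), so fragments without any downstream come
    /// first and the most upstream fragments come last.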
    fn topo_order(&self) -> Result<Vec<Id>> {
        let mut topo = Vec::new();
        let mut downstream_cnts = HashMap::new();

        // Iterate over all nodes to find the roots (fragments without any downstream) and
        // initialize the downstream counts.
        for node_id in self.nodes.keys() {
            let downstream_cnt = self.downstreams(*node_id).len();
            if downstream_cnt == 0 {
                topo.push(*node_id);
            } else {
                downstream_cnts.insert(*node_id, downstream_cnt);
            }
        }

        let mut i = 0;
        while let Some(&node_id) = topo.get(i) {
            i += 1;
            // Find if we can process more nodes.
            for &upstream_id in self.upstreams(node_id) {
                let downstream_cnt = downstream_cnts.get_mut(&upstream_id).unwrap();
                *downstream_cnt -= 1;
                if *downstream_cnt == 0 {
                    downstream_cnts.remove(&upstream_id);
                    topo.push(upstream_id);
                }
            }
        }

        if !downstream_cnts.is_empty() {
            // There are nodes that are not processed yet.
            bail_graph!("fragment graph is not a DAG");
        }
        assert_eq!(topo.len(), self.len());

        Ok(topo)
    }

    /// Calculates the fingerprints of all fragments based on their position in the graph.
    ///
    /// This is used to locate the candidates when matching a fragment against another graph.
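    ///
    /// The fingerprint of a fragment is a hash of its upstream count, its downstream count, and
    /// the (sorted) fingerprints of its upstreams, so only fragments at the same topological
    /// position can become candidates for each other.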
    fn fingerprints(&self) -> Result<HashMap<Id, u64>> {
        let mut fps = HashMap::new();

        let order = self.topo_order()?;
        for u in order.into_iter().rev() {
            let upstream_fps = self
                .upstreams(u)
                .iter()
                .map(|id| *fps.get(id).unwrap())
                .sorted() // allow input order to be arbitrary
                .collect_vec();

            // Hash the upstream count, downstream count, and upstream fingerprints to
            // generate the fingerprint of this node.
            let mut hasher = DefaultHasher::new();
            (
                self.upstreams(u).len(),
                self.downstreams(u).len(),
                upstream_fps,
            )
                .hash(&mut hasher);
            let fingerprint = hasher.finish();

            fps.insert(u, fingerprint);
        }

        Ok(fps)
    }
}

/// The match result of a fragment in the source graph to a fragment in the target graph.
struct Match {
    /// The target fragment id.
    target: Id,
    /// The mapping from source table id to target table within the fragment.
    table_matches: HashMap<TableId, PbTable>,
}

/// The successful matching result of two [`Graph`]s.
struct Matches {
    /// The mapping from source fragment id to target fragment id.
    inner: HashMap<Id, Match>,
    /// The set of target fragment ids that are already matched.
    matched_targets: HashSet<Id>,
}

impl Matches {
    /// Creates a new empty match result.
    fn new() -> Self {
        Self {
            inner: HashMap::new(),
            matched_targets: HashSet::new(),
        }
    }

    /// Returns the target fragment id of a source fragment id.
    fn target(&self, u: Id) -> Option<Id> {
        self.inner.get(&u).map(|m| m.target)
    }

    /// Returns the number of matched fragments.
    fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns true if the source fragment is already matched.
    fn matched(&self, u: Id) -> bool {
        self.inner.contains_key(&u)
    }

    /// Returns true if the target fragment is already matched.
    fn target_matched(&self, v: Id) -> bool {
        self.matched_targets.contains(&v)
    }

    /// Tries to match a source fragment to a target fragment. If successful, they will be recorded
    /// in the match result.
    ///
    /// This will check the operators and internal tables of the fragment.
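    ///
    /// Operators are visited with a BFS on both sides in lockstep, skipping stateless operators
    /// that have a single input. For every visited pair, the node types and input counts must
    /// agree, and each internal table must have exactly one counterpart with the same name and a
    /// compatible "physical" schema (see [`TableDesc`]).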
    fn try_match(&mut self, u: &Fragment, v: &Fragment) -> Result<()> {
        if self.matched(u.id) {
            panic!("fragment {} was already matched", u.id);
        }

        // Collect the internal tables of a node (not visiting children).
        let collect_tables = |x: &StreamNode| {
            let mut tables = Vec::new();
            visit_stream_node_tables_inner(&mut x.clone(), true, false, |table, name| {
                tables.push((name.to_owned(), table.clone()));
            });
            tables
        };

        let mut table_matches = HashMap::new();

        // Use BFS to match the operator nodes.
        let mut uq = VecDeque::from([&u.root]);
        let mut vq = VecDeque::from([&v.root]);

        while let Some(mut un) = uq.pop_front() {
            // Since we ensure the number of inputs of an operator is the same before extending
            // the BFS queue, we can safely unwrap here.
            let mut vn = vq.pop_front().unwrap();

            // Skip while the node is stateless and has only one input.
            let mut u_tables = collect_tables(un);
            while u_tables.is_empty() && un.input.len() == 1 {
                un = &un.input[0];
                u_tables = collect_tables(un);
            }
            let mut v_tables = collect_tables(vn);
            while v_tables.is_empty() && vn.input.len() == 1 {
                vn = &vn.input[0];
                v_tables = collect_tables(vn);
            }

            // If we have reached the leaf nodes on both sides, we are done with this branch.
            if un.input.is_empty() && vn.input.is_empty() {
                continue;
            }

            // Perform checks.
            if un.node_body.as_ref().unwrap().discriminant()
                != vn.node_body.as_ref().unwrap().discriminant()
            {
                bail_operator!(from = un, to = vn, "operator has different type");
            }
            if un.input.len() != vn.input.len() {
                bail_operator!(
                    from = un,
                    to = vn,
                    "operator has different number of inputs ({} vs {})",
                    un.input.len(),
                    vn.input.len()
                );
            }

            // Extend the BFS queue.
            uq.extend(un.input.iter());
            vq.extend(vn.input.iter());

            for (ut_name, ut) in u_tables {
                let vt_cands = v_tables
                    .extract_if(.., |(vt_name, _)| *vt_name == ut_name)
                    .collect_vec();

                if vt_cands.is_empty() {
                    bail_operator!(
                        from = un,
                        to = vn,
                        "cannot find a match for table `{ut_name}`",
                    );
                } else if vt_cands.len() > 1 {
                    bail_operator!(
                        from = un,
                        to = vn,
                        "found multiple matches for table `{ut_name}`",
                    );
                }

                let (_, vt) = vt_cands.into_iter().next().unwrap();

                // Since the requirement is to ensure the state compatibility, we focus solely on
                // the "physical" part of the table, best illustrated by `TableDesc`.
                let table_desc_for_compare = |table: &PbTable| {
                    let mut table = table.clone();
                    table.id = 0.into(); // ignore id
                    table.maybe_vnode_count = Some(42); // vnode count is unfilled for new fragment graph, fill it with a dummy value before proceeding

                    TableDesc::from_pb_table(&table)
                };

                let ut_compare = table_desc_for_compare(&ut);
                let vt_compare = table_desc_for_compare(&vt);

                if ut_compare != vt_compare {
                    bail_operator!(
                        from = un,
                        to = vn,
                        "found a match for table `{ut_name}`, but they are incompatible, diff:\n{}",
                        pretty_assertions::Comparison::new(&ut_compare, &vt_compare)
                    );
                }

                table_matches.try_insert(ut.id, vt).unwrap_or_else(|_| {
                    panic!("duplicated table id {} in fragment {}", ut.id, u.id)
                });
            }
        }

        let m = Match {
            target: v.id,
            table_matches,
        };
        self.inner.insert(u.id, m);
        self.matched_targets.insert(v.id);

        Ok(())
    }

    /// Undoes the match of a source fragment.
    fn undo_match(&mut self, u: Id) {
        let target = self
            .inner
            .remove(&u)
            .unwrap_or_else(|| panic!("fragment {} was not previously matched", u))
            .target;

        let target_removed = self.matched_targets.remove(&target);
        assert!(target_removed);
    }

    /// Converts the match result into a table mapping.
    fn into_table_mapping(self) -> HashMap<TableId, PbTable> {
        self.inner
            .into_iter()
            .flat_map(|(_, m)| m.table_matches.into_iter())
            .collect()
    }
}

/// Matches two [`Graph`]s, and returns the match result from each fragment in `g1` to `g2`.
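///
/// Candidates are first narrowed down by fingerprint equality, then a depth-first search with
/// backtracking assigns each fragment of `g1` to a distinct fragment of `g2`, always expanding
/// the unmatched fragment with the fewest remaining candidates first.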
fn match_graph(g1: &Graph, g2: &Graph) -> Result<Matches> {
    if g1.len() != g2.len() {
        bail_graph!(
            "graphs have different number of fragments ({} vs {})",
            g1.len(),
            g2.len()
        );
    }

    let fps1 = g1.fingerprints()?;
    let fps2 = g2.fingerprints()?;

    // Collect the candidates for each fragment.
    let mut fp_cand = HashMap::with_capacity(g1.len());
    for (&u, &f1) in &fps1 {
        for (&v, &f2) in &fps2 {
            if f1 == f2 {
                fp_cand.entry(u).or_insert_with(HashSet::new).insert(v);
            }
        }
    }

    fn dfs(
        g1: &Graph,
        g2: &Graph,
        fp_cand: &mut HashMap<Id, HashSet<Id>>,
        matches: &mut Matches,
    ) -> Result<()> {
        // If all fragments are matched, return.
        if matches.len() == g1.len() {
            return Ok(());
        }

        // Choose the unmatched fragment with the fewest remaining candidates.
        let (&u, u_cands) = fp_cand
            .iter()
            .filter(|(u, _)| !matches.matched(**u))
            .min_by_key(|(_, cands)| cands.len())
            .unwrap();
        let u_cands = u_cands.clone();

        let mut last_error = None;

        'candidates: for &v in &u_cands {
            // Skip if v is already used.
            if matches.target_matched(v) {
                continue;
            }

            // For each upstream of u, if it's already matched, then it must be matched to the
            // corresponding upstream of v.
            let upstreams = g1.upstreams(u).to_vec();
            for u_upstream in upstreams {
                if let Some(v_upstream) = matches.target(u_upstream)
                    && !g2.upstreams(v).contains(&v_upstream)
                {
                    // Not a valid match, try the next candidate.
                    continue 'candidates;
                }
            }
            // Same for the downstreams of u.
            let downstreams = g1.downstreams(u).to_vec();
            for u_downstream in downstreams {
                if let Some(v_downstream) = matches.target(u_downstream)
                    && !g2.downstreams(v).contains(&v_downstream)
                {
                    // Not a valid match, try the next candidate.
                    continue 'candidates;
                }
            }

            // Now that `u` and `v` are at the same position in the graph, try to match them by
            // visiting their operators.
            match matches.try_match(&g1.nodes[&u], &g2.nodes[&v]) {
                Ok(()) => {
                    let fp_cand_clone = fp_cand.clone();

                    // v cannot be a candidate for any other u. Remove it before proceeding.
                    for (_, u_cands) in fp_cand.iter_mut() {
                        u_cands.remove(&v);
                    }

                    // Try to match the rest.
                    match dfs(g1, g2, fp_cand, matches) {
                        Ok(()) => return Ok(()), // success, return
                        Err(err) => {
                            last_error = Some(err);

                            // Backtrack.
                            *fp_cand = fp_cand_clone;
                            matches.undo_match(u);
                        }
                    }
                }

                Err(err) => last_error = Some(err),
            }
        }

        if let Some(error) = last_error {
            bail_fragment!(
                source = Box::new(error),
                id = u,
                "tried against all {} candidates, but failed",
                u_cands.len()
            );
        } else {
            bail_fragment!(
                id = u,
                "cannot find a candidate with same topological position"
            )
        }
    }

    let mut matches = Matches::new();
    dfs(g1, g2, &mut fp_cand, &mut matches)?;
    Ok(matches)
}

/// Matches two [`Graph`]s, and returns the internal table mapping from `g1` to `g2`.
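///
/// The returned map is keyed by the table ids of `g1`, with the matched table catalogs of `g2`
/// as values.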
pub(crate) fn match_graph_internal_tables(
    g1: &Graph,
    g2: &Graph,
) -> Result<HashMap<TableId, PbTable>> {
    match_graph(g1, g2).map(|matches| matches.into_table_mapping())
}

impl Graph {
    /// Creates a [`Graph`] from a [`StreamFragmentGraph`] that's being built.
    pub(crate) fn from_building(graph: &StreamFragmentGraph) -> Self {
        let nodes = graph
            .fragments
            .iter()
            .map(|(&id, f)| {
                let id = id.as_global_id();
                (
                    id,
                    Fragment {
                        id,
                        root: f.node.clone().unwrap(),
                    },
                )
            })
            .collect();

        let downstreams = graph
            .downstreams
            .iter()
            .map(|(&id, downstreams)| {
                (
                    id.as_global_id(),
                    downstreams
                        .iter()
                        .map(|(&id, _)| id.as_global_id())
                        .collect(),
                )
            })
            .collect();

        let upstreams = graph
            .upstreams
            .iter()
            .map(|(&id, upstreams)| {
                (
                    id.as_global_id(),
                    upstreams.iter().map(|(&id, _)| id.as_global_id()).collect(),
                )
            })
            .collect();

        Self {
            nodes,
            downstreams,
            upstreams,
        }
    }

    /// Creates a [`Graph`] from a [`StreamJobFragments`] that's existing.
    pub(crate) fn from_existing(
        fragments: &StreamJobFragments,
        fragment_upstreams: &HashMap<Id, HashSet<Id>>,
    ) -> Self {
        let nodes: HashMap<_, _> = fragments
            .fragments
            .iter()
            .map(|(&id, f)| {
                (
                    id,
                    Fragment {
                        id,
                        root: f.nodes.clone(),
                    },
                )
            })
            .collect();

        let mut downstreams = HashMap::new();
        let mut upstreams = HashMap::new();

        for (&id, fragment_upstreams) in fragment_upstreams {
            assert!(nodes.contains_key(&id));

            for &upstream in fragment_upstreams {
                if !nodes.contains_key(&upstream) {
                    // Upstream fragment can be from a different job, ignore it.
                    continue;
                }
                downstreams
                    .entry(upstream)
                    .or_insert_with(Vec::new)
                    .push(id);
                upstreams.entry(id).or_insert_with(Vec::new).push(upstream);
            }
        }

        Self {
            nodes,
            downstreams,
            upstreams,
        }
    }
}