risingwave_frontend/optimizer/plan_node/stream_locality_provider.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use itertools::Itertools;
use pretty_xmlish::XmlNode;
use risingwave_common::catalog::Field;
use risingwave_common::hash::VirtualNode;
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::OrderType;
use risingwave_pb::stream_plan::LocalityProviderNode;
use risingwave_pb::stream_plan::stream_node::PbNodeBody;

use super::stream::prelude::*;
use super::utils::{Distill, TableCatalogBuilder, childless_record};
use super::{ExprRewritable, PlanTreeNodeUnary, StreamNode, StreamPlanRef as PlanRef, generic};
use crate::TableCatalog;
use crate::catalog::TableId;
use crate::expr::{ExprRewriter, ExprVisitor};
use crate::optimizer::plan_node::PlanBase;
use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
use crate::optimizer::property::Distribution;
use crate::stream_fragmenter::BuildFragmentGraphState;

/// `StreamLocalityProvider` implements [`super::LogicalLocalityProvider`]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StreamLocalityProvider {
    pub base: PlanBase<Stream>,
    core: generic::LocalityProvider<PlanRef>,
}

impl StreamLocalityProvider {
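    /// Creates a `StreamLocalityProvider` on top of `core.input`. The input must be
    /// hash-distributed; its distribution is turned into an `UpstreamHashShard`
    /// distribution so that the provider is placed in its own fragment.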
    pub fn new(core: generic::LocalityProvider<PlanRef>) -> Self {
        let input = core.input.clone();

        let dist = match input.distribution() {
            Distribution::HashShard(keys) => {
                // If the input is hash-distributed, convert it to an `UpstreamHashShard`
                // distribution, just like a normal table scan. This ensures the locality
                // provider is placed in its own fragment, which is important for the
                // backfill ordering to recognize it and build the dependency graph among
                // the backfill-needed fragments.
                Distribution::UpstreamHashShard(keys.clone(), TableId::placeholder())
            }
            Distribution::UpstreamHashShard(keys, table_id) => {
                Distribution::UpstreamHashShard(keys.clone(), *table_id)
            }
            _ => {
                panic!("LocalityProvider input must be hash-distributed");
            }
        };

        // LocalityProvider preserves the input's stream kind, so an append-only
        // input stays append-only.
        let base = PlanBase::new_stream_with_core(
            &core,
            dist,
            input.stream_kind(),
            input.emit_on_window_close(),
            input.watermark_columns().clone(),
            input.columns_monotonicity().clone(),
        );
        StreamLocalityProvider { base, core }
    }

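    /// Indices (into the input schema) of the columns that define the provided
    /// locality; they form the prefix of the state table's primary key.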
    pub fn locality_columns(&self) -> &[usize] {
        &self.core.locality_columns
    }
}

impl PlanTreeNodeUnary<Stream> for StreamLocalityProvider {
    fn input(&self) -> PlanRef {
        self.core.input.clone()
    }

    fn clone_with_input(&self, input: PlanRef) -> Self {
        let mut core = self.core.clone();
        core.input = input;
        Self::new(core)
    }
}

impl_plan_tree_node_for_unary! { Stream, StreamLocalityProvider }

impl Distill for StreamLocalityProvider {
    fn distill<'a>(&self) -> XmlNode<'a> {
        let vec = self.core.fields_pretty();
        childless_record("StreamLocalityProvider", vec)
    }
}

impl StreamNode for StreamLocalityProvider {
    fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody {
        let state_table = self.build_state_catalog(state);
        let progress_table = self.build_progress_catalog(state);

        let locality_provider_node = LocalityProviderNode {
            locality_columns: self.locality_columns().iter().map(|&i| i as u32).collect(),
            // State table for buffering input data
            state_table: Some(state_table.to_prost()),
            // Progress table for tracking backfill progress
            progress_table: Some(progress_table.to_prost()),
        };

        PbNodeBody::LocalityProvider(Box::new(locality_provider_node))
    }
}

impl ExprRewritable<Stream> for StreamLocalityProvider {
    fn has_rewritable_expr(&self) -> bool {
        false
    }

    fn rewrite_exprs(&self, _r: &mut dyn ExprRewriter) -> PlanRef {
        self.clone().into()
    }
}

impl ExprVisitable for StreamLocalityProvider {
    fn visit_exprs(&self, _v: &mut dyn ExprVisitor) {
        // No expressions to visit
    }
}

impl StreamLocalityProvider {
    /// Build the state table catalog for buffering input data.
    /// Schema: same as input schema (locality handled by primary key ordering).
    /// Key: `locality_columns` + input stream key (vnode handled internally by `StateTable`).
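    ///
    /// For illustration (hypothetical input): with input columns `(a, b, c)`,
    /// `locality_columns = [1]`, and stream key `[0]`, the state table is keyed by
    /// `(b ASC, a ASC)` and stores all of `(a, b, c)` as value columns.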
    fn build_state_catalog(&self, state: &mut BuildFragmentGraphState) -> TableCatalog {
        let mut catalog_builder = TableCatalogBuilder::default();
        let input = self.input();
        let input_schema = input.schema();

        // Add all input columns in original order
        for field in &input_schema.fields {
            catalog_builder.add_column(field);
        }

        // Set locality columns as the prefix of the primary key.
        for locality_col_idx in self.locality_columns() {
            catalog_builder.add_order_column(*locality_col_idx, OrderType::ascending());
        }
        // Add the input's stream key as the rest of the primary key.
        for &key_col_idx in input.expect_stream_key() {
            catalog_builder.add_order_column(key_col_idx, OrderType::ascending());
        }

        catalog_builder.set_value_indices((0..input_schema.len()).collect());

        catalog_builder
            .build(
                self.input().distribution().dist_column_indices().to_vec(),
                0,
            )
            .with_id(state.gen_table_id_wrapped())
    }

    /// Build the progress table catalog for tracking backfill progress.
    /// Schema: | vnode | pk(locality columns + input stream keys) | `backfill_finished` | `row_count` |
    /// Key: | vnode |
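    ///
    /// For illustration (hypothetical input): with input columns `(a, b, c)`,
    /// `locality_columns = [1]`, and stream key `[0]`, the progress table columns are
    /// `(vnode, b, a, backfill_finished, row_count)`, keyed and distributed by `vnode`.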
    fn build_progress_catalog(&self, state: &mut BuildFragmentGraphState) -> TableCatalog {
        let mut catalog_builder = TableCatalogBuilder::default();
        let input = self.input();
        let input_schema = input.schema();

        // Add the vnode column as the primary key
        catalog_builder.add_column(&Field::with_name(VirtualNode::RW_TYPE, "vnode"));
        catalog_builder.add_order_column(0, OrderType::ascending());

        // Add locality columns, which record the current backfill position per vnode
        for &locality_col_idx in self.locality_columns() {
            let field = &input_schema.fields[locality_col_idx];
            catalog_builder.add_column(field);
        }

        // Add the input stream key columns as the rest of the backfill position
        for &key_col_idx in input.expect_stream_key() {
            let field = &input_schema.fields[key_col_idx];
            catalog_builder.add_column(field);
        }

        // Add backfill_finished column
        catalog_builder.add_column(&Field::with_name(DataType::Boolean, "backfill_finished"));

        // Add row_count column
        catalog_builder.add_column(&Field::with_name(DataType::Int64, "row_count"));

        // Set vnode column index and distribution key
        catalog_builder.set_vnode_col_idx(0);
        catalog_builder.set_dist_key_in_pk(vec![0]);

        let num_of_columns = catalog_builder.columns().len();
        catalog_builder.set_value_indices((0..num_of_columns).collect_vec());

        catalog_builder
            .build(vec![0], 1)
            .with_id(state.gen_table_id_wrapped())
    }
}