risingwave_frontend/optimizer/plan_node/
stream_locality_provider.rs1use itertools::Itertools;
16use pretty_xmlish::XmlNode;
17use risingwave_common::catalog::Field;
18use risingwave_common::hash::VirtualNode;
19use risingwave_common::types::DataType;
20use risingwave_common::util::sort_util::OrderType;
21use risingwave_pb::stream_plan::LocalityProviderNode;
22use risingwave_pb::stream_plan::stream_node::PbNodeBody;
23
24use super::stream::prelude::*;
25use super::utils::{Distill, TableCatalogBuilder, childless_record};
26use super::{ExprRewritable, PlanTreeNodeUnary, StreamNode, StreamPlanRef as PlanRef, generic};
27use crate::TableCatalog;
28use crate::catalog::TableId;
29use crate::expr::{ExprRewriter, ExprVisitor};
30use crate::optimizer::plan_node::PlanBase;
31use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
32use crate::optimizer::property::Distribution;
33use crate::stream_fragmenter::BuildFragmentGraphState;
34
/// Stream plan node that wraps a single input together with a set of locality columns.
///
/// When serialized (see `to_stream_prost_body`), it carries two internal tables:
/// a state table holding every input column, keyed by the locality columns followed by the
/// input's stream key, and a progress table keyed by vnode.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StreamLocalityProvider {
    /// Shared plan-node properties (schema, distribution, stream kind, watermarks, ...).
    pub base: PlanBase<Stream>,
    /// Logical core: the input plan and the locality column indices.
    core: generic::LocalityProvider<PlanRef>,
}
41
42impl StreamLocalityProvider {
43 pub fn new(core: generic::LocalityProvider<PlanRef>) -> Self {
44 let input = core.input.clone();
45
46 let dist = match input.distribution() {
47 Distribution::HashShard(keys) => {
48 Distribution::UpstreamHashShard(keys.clone(), TableId::placeholder())
53 }
54 Distribution::UpstreamHashShard(keys, table_id) => {
55 Distribution::UpstreamHashShard(keys.clone(), *table_id)
56 }
57 _ => {
58 panic!("LocalityProvider input must be hash-distributed");
59 }
60 };
61
62 let base = PlanBase::new_stream_with_core(
64 &core,
65 dist,
66 input.stream_kind(),
67 input.emit_on_window_close(),
68 input.watermark_columns().clone(),
69 input.columns_monotonicity().clone(),
70 );
71 StreamLocalityProvider { base, core }
72 }
73
74 pub fn locality_columns(&self) -> &[usize] {
75 &self.core.locality_columns
76 }
77}
78
79impl PlanTreeNodeUnary<Stream> for StreamLocalityProvider {
80 fn input(&self) -> PlanRef {
81 self.core.input.clone()
82 }
83
84 fn clone_with_input(&self, input: PlanRef) -> Self {
85 let mut core = self.core.clone();
86 core.input = input;
87 Self::new(core)
88 }
89}
90
91impl_plan_tree_node_for_unary! { Stream, StreamLocalityProvider }
92
93impl Distill for StreamLocalityProvider {
94 fn distill<'a>(&self) -> XmlNode<'a> {
95 let vec = self.core.fields_pretty();
96 childless_record("StreamLocalityProvider", vec)
97 }
98}
99
100impl StreamNode for StreamLocalityProvider {
101 fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody {
102 let state_table = self.build_state_catalog(state);
103 let progress_table = self.build_progress_catalog(state);
104
105 let locality_provider_node = LocalityProviderNode {
106 locality_columns: self.locality_columns().iter().map(|&i| i as u32).collect(),
107 state_table: Some(state_table.to_prost()),
109 progress_table: Some(progress_table.to_prost()),
111 };
112
113 PbNodeBody::LocalityProvider(Box::new(locality_provider_node))
114 }
115}
116
117impl ExprRewritable<Stream> for StreamLocalityProvider {
118 fn has_rewritable_expr(&self) -> bool {
119 false
120 }
121
122 fn rewrite_exprs(&self, _r: &mut dyn ExprRewriter) -> PlanRef {
123 self.clone().into()
124 }
125}
126
127impl ExprVisitable for StreamLocalityProvider {
128 fn visit_exprs(&self, _v: &mut dyn ExprVisitor) {
129 }
131}
132
133impl StreamLocalityProvider {
134 fn build_state_catalog(&self, state: &mut BuildFragmentGraphState) -> TableCatalog {
138 let mut catalog_builder = TableCatalogBuilder::default();
139 let input = self.input();
140 let input_schema = input.schema();
141
142 for field in &input_schema.fields {
144 catalog_builder.add_column(field);
145 }
146
147 for locality_col_idx in self.locality_columns() {
149 catalog_builder.add_order_column(*locality_col_idx, OrderType::ascending());
150 }
151 for &key_col_idx in input.expect_stream_key() {
153 catalog_builder.add_order_column(key_col_idx, OrderType::ascending());
154 }
155
156 catalog_builder.set_value_indices((0..input_schema.len()).collect());
157
158 catalog_builder
159 .build(
160 self.input().distribution().dist_column_indices().to_vec(),
161 0,
162 )
163 .with_id(state.gen_table_id_wrapped())
164 }
165
166 fn build_progress_catalog(&self, state: &mut BuildFragmentGraphState) -> TableCatalog {
170 let mut catalog_builder = TableCatalogBuilder::default();
171 let input = self.input();
172 let input_schema = input.schema();
173
174 catalog_builder.add_column(&Field::with_name(VirtualNode::RW_TYPE, "vnode"));
176 catalog_builder.add_order_column(0, OrderType::ascending());
177
178 for &locality_col_idx in self.locality_columns() {
180 let field = &input_schema.fields[locality_col_idx];
181 catalog_builder.add_column(field);
182 }
183
184 for &key_col_idx in input.expect_stream_key() {
186 let field = &input_schema.fields[key_col_idx];
187 catalog_builder.add_column(field);
188 }
189
190 catalog_builder.add_column(&Field::with_name(DataType::Boolean, "backfill_finished"));
192
193 catalog_builder.add_column(&Field::with_name(DataType::Int64, "row_count"));
195
196 catalog_builder.set_vnode_col_idx(0);
198 catalog_builder.set_dist_key_in_pk(vec![0]);
199
200 let num_of_columns = catalog_builder.columns().len();
201 catalog_builder.set_value_indices((0..num_of_columns).collect_vec());
202
203 catalog_builder
204 .build(vec![0], 1)
205 .with_id(state.gen_table_id_wrapped())
206 }
207}