Skip to main content

risingwave_frontend/optimizer/
mod.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::num::NonZeroU32;
16use std::ops::DerefMut;
17use std::sync::Arc;
18
19use risingwave_pb::catalog::PbVectorIndexInfo;
20
21pub mod plan_node;
22
23use plan_node::StreamFilter;
24pub use plan_node::{Explain, LogicalPlanRef, PlanRef};
25
26pub mod property;
27
28mod delta_join_solver;
29mod heuristic_optimizer;
30mod plan_rewriter;
31
32mod plan_visitor;
33
34#[cfg(feature = "datafusion")]
35pub use plan_visitor::DataFusionExecuteCheckerExt;
36pub use plan_visitor::{
37    ExecutionModeDecider, PlanVisitor, RelationCollectorVisitor, SysTableVisitor,
38};
39use risingwave_pb::plan_common::source_refresh_mode::RefreshMode;
40
41pub mod backfill_order_strategy;
42mod logical_optimization;
43mod optimizer_context;
44pub mod plan_expr_rewriter;
45mod plan_expr_visitor;
46mod rule;
47
48use std::collections::{BTreeMap, HashMap};
49use std::marker::PhantomData;
50
51use educe::Educe;
52use fixedbitset::FixedBitSet;
53use itertools::Itertools;
54pub use logical_optimization::*;
55pub use optimizer_context::*;
56use plan_expr_rewriter::ConstEvalRewriter;
57use property::Order;
58use risingwave_common::bail;
59use risingwave_common::catalog::{ColumnCatalog, ColumnDesc, ConflictBehavior, Field, Schema};
60use risingwave_common::types::DataType;
61use risingwave_common::util::column_index_mapping::ColIndexMapping;
62use risingwave_common::util::iter_util::ZipEqDebug;
63use risingwave_connector::WithPropertiesExt;
64use risingwave_connector::sink::catalog::SinkFormatDesc;
65use risingwave_pb::stream_plan::StreamScanType;
66
67use self::heuristic_optimizer::ApplyOrder;
68use self::plan_node::generic::{self, PhysicalPlanRef};
69use self::plan_node::{
70    BatchProject, LogicalProject, LogicalSource, PartitionComputeInfo, StreamDml,
71    StreamMaterialize, StreamProject, StreamRowIdGen, StreamSink, StreamWatermarkFilter,
72    ToStreamContext, stream_enforce_eowc_requirement,
73};
74#[cfg(debug_assertions)]
75use self::plan_visitor::InputRefValidator;
76use self::plan_visitor::{CardinalityVisitor, StreamKeyChecker, has_batch_exchange};
77use self::property::{Cardinality, RequiredDist};
78use self::rule::*;
79use crate::TableCatalog;
80use crate::catalog::table_catalog::TableType;
81use crate::catalog::{DatabaseId, SchemaId};
82use crate::error::{ErrorCode, Result};
83use crate::expr::TimestamptzExprFinder;
84use crate::handler::create_table::{CreateTableInfo, CreateTableProps};
85use crate::optimizer::plan_node::generic::{GenericPlanRef, SourceNodeKind, Union};
86use crate::optimizer::plan_node::{
87    BackfillType, Batch, BatchExchange, BatchPlanNodeType, BatchPlanRef, ConventionMarker,
88    PlanTreeNode, RewriteStreamContext, Stream, StreamExchange, StreamPlanRef, StreamUnion,
89    StreamUpstreamSinkUnion, StreamVectorIndexWrite, ToStream, VisitExprsRecursive,
90};
91use crate::optimizer::plan_visitor::{
92    LocalityProviderCounter, RwTimestampValidator, TemporalJoinValidator,
93};
94use crate::optimizer::property::Distribution;
95use crate::utils::{
96    ColIndexMappingRewriteExt, MV_REFRESH_INTERVAL_SEC_KEY, WithOptionsSecResolved,
97};
98
/// `PlanRoot` describes a plan together with the properties required at its root.
///
/// The planner constructs a `PlanRoot` from a logical plan plus the required distribution and
/// order, and the `PlanRoot` can then generate the corresponding streaming or batch plan through
/// optimization. The columns referenced by the required order and distribution may be more than
/// the output columns. For example:
/// ```sql
///    select v1 from t order by id;
/// ```
/// Here the plan returns two columns (id, v1) and the required order column is `id`. The `id`
/// column is needed during optimization, but the final generated plan removes the unnecessary
/// column from the result (see `out_fields`).
#[derive(Educe)]
#[educe(Debug, Clone)]
pub struct PlanRoot<P: PlanPhase> {
    // The current plan node; its convention is fixed by the phase `P`.
    pub plan: PlanRef<P::Convention>,
    // Type-level marker of the phase of the plan; carries no data.
    #[educe(Debug(ignore), Clone(method(PhantomData::clone)))]
    _phase: PhantomData<P>,
    // Distribution required at the root of `plan`.
    required_dist: RequiredDist,
    // Order required on the output of `plan`.
    required_order: Order,
    // One bit per column of `plan`'s schema; set bits mark the columns that are output
    // (the length is checked against the schema in `new_inner`).
    out_fields: FixedBitSet,
    // Names of the output columns, one per set bit of `out_fields`.
    out_names: Vec<String>,
}
122
/// `PlanPhase` tracks, at the type level, the phase of a [`PlanRoot`].
///
/// Each phase fixes the convention of the plan held by the root. Usually a root begins in the
/// `Logical` phase and ends in `Batch` or `Stream`, unless a `PlanRoot` is deliberately
/// constructed from an intermediate phase. Typical phase transformations are:
/// - `Logical` -> `BatchOptimizedLogical` -> `Batch`
/// - `Logical` -> `StreamOptimizedLogical` -> `Stream`
///
/// Note that the `StreamOptimizedLogical` phase already holds a stream-convention plan.
pub trait PlanPhase {
    /// Convention (`Logical`, `Batch` or `Stream`) of the plan held in this phase.
    type Convention: ConventionMarker;
}
131
/// Declares the plan phases.
///
/// For every `{ Phase, Convention }` pair this generates:
/// - a marker type `PlanPhase<Phase>` (e.g. `PlanPhaseBatch`),
/// - an implementation of [`PlanPhase`] whose `Convention` is the given type, and
/// - a type alias `<Phase>PlanRoot` (e.g. `BatchPlanRoot`) for `PlanRoot<PlanPhase<Phase>>`.
///
/// Invoked with no arguments, it expands the built-in list of phases.
macro_rules! for_all_phase {
    () => {
        // Built-in phases as `{ phase name, convention of the plan held in that phase }`.
        // `StreamOptimizedLogical` already carries a stream-convention plan.
        for_all_phase! {
            { Logical, $crate::optimizer::plan_node::Logical },
            { BatchOptimizedLogical, $crate::optimizer::plan_node::Logical },
            { StreamOptimizedLogical, $crate::optimizer::plan_node::Stream },
            { Batch, $crate::optimizer::plan_node::Batch },
            { Stream, $crate::optimizer::plan_node::Stream }
        }
    };
    ($({$phase:ident, $convention:ty}),+ $(,)?) => {
        $(
            // `paste!` concatenates identifiers to build the marker type and alias names.
            paste::paste! {
                pub struct [< PlanPhase$phase >];
                impl PlanPhase for [< PlanPhase$phase >] {
                    type Convention = $convention;
                }
                pub type [< $phase PlanRoot >] = PlanRoot<[< PlanPhase$phase >]>;
            }
        )+
    }
}

// Generate `PlanPhaseLogical`/`LogicalPlanRoot`, `PlanPhaseBatch`/`BatchPlanRoot`, etc.
for_all_phase!();
156
157impl LogicalPlanRoot {
158    pub fn new_with_logical_plan(
159        plan: LogicalPlanRef,
160        required_dist: RequiredDist,
161        required_order: Order,
162        out_fields: FixedBitSet,
163        out_names: Vec<String>,
164    ) -> Self {
165        Self::new_inner(plan, required_dist, required_order, out_fields, out_names)
166    }
167}
168
169impl BatchPlanRoot {
170    pub fn new_with_batch_plan(
171        plan: BatchPlanRef,
172        required_dist: RequiredDist,
173        required_order: Order,
174        out_fields: FixedBitSet,
175        out_names: Vec<String>,
176    ) -> Self {
177        Self::new_inner(plan, required_dist, required_order, out_fields, out_names)
178    }
179}
180
181impl<P: PlanPhase> PlanRoot<P> {
182    fn new_inner(
183        plan: PlanRef<P::Convention>,
184        required_dist: RequiredDist,
185        required_order: Order,
186        out_fields: FixedBitSet,
187        out_names: Vec<String>,
188    ) -> Self {
189        let input_schema = plan.schema();
190        assert_eq!(input_schema.fields().len(), out_fields.len());
191        assert_eq!(out_fields.count_ones(..), out_names.len());
192
193        Self {
194            plan,
195            _phase: PhantomData,
196            required_dist,
197            required_order,
198            out_fields,
199            out_names,
200        }
201    }
202
203    fn into_phase<P2: PlanPhase>(self, plan: PlanRef<P2::Convention>) -> PlanRoot<P2> {
204        PlanRoot {
205            plan,
206            _phase: PhantomData,
207            required_dist: self.required_dist,
208            required_order: self.required_order,
209            out_fields: self.out_fields,
210            out_names: self.out_names,
211        }
212    }
213
214    /// Set customized names of the output fields, used for `CREATE [MATERIALIZED VIEW | SINK] r(a,
215    /// b, ..)`.
216    ///
217    /// If the number of names does not match the number of output fields, an error is returned.
218    pub fn set_out_names(&mut self, out_names: Vec<String>) -> Result<()> {
219        if out_names.len() != self.out_fields.count_ones(..) {
220            Err(ErrorCode::InvalidInputSyntax(
221                "number of column names does not match number of columns".to_owned(),
222            ))?
223        }
224        self.out_names = out_names;
225        Ok(())
226    }
227
228    /// Get the plan root's schema, only including the fields to be output.
229    pub fn schema(&self) -> Schema {
230        // The schema can be derived from the `out_fields` and `out_names`, so we don't maintain it
231        // as a field and always construct one on demand here to keep it in sync.
232        Schema {
233            fields: self
234                .out_fields
235                .ones()
236                .map(|i| self.plan.schema().fields()[i].clone())
237                .zip_eq_debug(&self.out_names)
238                .map(|(field, name)| Field {
239                    name: name.clone(),
240                    ..field
241                })
242                .collect(),
243        }
244    }
245}
246
247impl LogicalPlanRoot {
248    /// Transform the [`PlanRoot`] back to a [`PlanRef`] suitable to be used as a subplan, for
249    /// example as insert source or subquery. This ignores Order but retains post-Order pruning
250    /// (`out_fields`).
251    pub fn into_unordered_subplan(self) -> LogicalPlanRef {
252        if self.out_fields.count_ones(..) == self.out_fields.len() {
253            return self.plan;
254        }
255        LogicalProject::with_out_fields(self.plan, &self.out_fields).into()
256    }
257
258    /// Transform the [`PlanRoot`] wrapped in an array-construction subquery to a [`PlanRef`]
259    /// supported by `ARRAY_AGG`. Similar to the unordered version, this abstracts away internal
260    /// `self.plan` which is further modified by `self.required_order` then `self.out_fields`.
261    pub fn into_array_agg(self) -> Result<LogicalPlanRef> {
262        use generic::Agg;
263        use plan_node::PlanAggCall;
264        use risingwave_common::types::ListValue;
265        use risingwave_expr::aggregate::PbAggKind;
266
267        use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
268        use crate::utils::{Condition, IndexSet};
269
270        let Ok(select_idx) = Itertools::exactly_one(self.out_fields.ones()) else {
271            bail!("subquery must return only one column");
272        };
273        let input_column_type = self.plan.schema().fields()[select_idx].data_type();
274        let return_type = DataType::list(input_column_type.clone());
275        let agg = Agg::new(
276            vec![PlanAggCall {
277                agg_type: PbAggKind::ArrayAgg.into(),
278                return_type: return_type.clone(),
279                inputs: vec![InputRef::new(select_idx, input_column_type.clone())],
280                distinct: false,
281                order_by: self.required_order.column_orders,
282                filter: Condition::true_cond(),
283                direct_args: vec![],
284            }],
285            IndexSet::empty(),
286            self.plan,
287        );
288        Ok(LogicalProject::create(
289            agg.into(),
290            vec![
291                FunctionCall::new(
292                    ExprType::Coalesce,
293                    vec![
294                        InputRef::new(0, return_type).into(),
295                        ExprImpl::literal_list(
296                            ListValue::empty(&input_column_type),
297                            input_column_type,
298                        ),
299                    ],
300                )
301                .unwrap()
302                .into(),
303            ],
304        ))
305    }
306
307    /// Apply logical optimization to the plan for stream.
308    pub fn gen_optimized_logical_plan_for_stream(mut self) -> Result<LogicalPlanRoot> {
309        self.plan = LogicalOptimizer::gen_optimized_logical_plan_for_stream(self.plan.clone())?;
310        Ok(self)
311    }
312
313    /// Apply logical optimization to the plan for batch.
314    pub fn gen_optimized_logical_plan_for_batch(self) -> Result<BatchOptimizedLogicalPlanRoot> {
315        let plan = LogicalOptimizer::gen_optimized_logical_plan_for_batch(self.plan.clone())?;
316        Ok(self.into_phase(plan))
317    }
318
319    pub fn gen_batch_plan(self) -> Result<BatchPlanRoot> {
320        self.gen_optimized_logical_plan_for_batch()?
321            .gen_batch_plan()
322    }
323}
324
impl BatchOptimizedLogicalPlanRoot {
    /// Optimize and generate a singleton batch physical plan without exchange nodes.
    ///
    /// Steps: reject dangling temporal scans, inline the session timezone and constant-fold
    /// expressions on the logical plan, convert to batch with the required order, merge
    /// `BatchProject`s, and inline the session timezone again on the physical plan.
    ///
    /// # Errors
    /// Returns `NotSupported` if the plan contains a dangling temporal scan (temporal joins are
    /// not supported for batch queries), and propagates errors from the conversion passes.
    ///
    /// # Panics
    /// Asserts that the resulting plan has `Distribution::Single` and contains no batch exchange.
    pub fn gen_batch_plan(self) -> Result<BatchPlanRoot> {
        if TemporalJoinValidator::exist_dangling_temporal_scan(self.plan.clone()) {
            return Err(ErrorCode::NotSupported(
                "do not support temporal join for batch queries".to_owned(),
                "please use temporal join in streaming queries".to_owned(),
            )
            .into());
        }

        let ctx = self.plan.ctx();
        // Inline session timezone mainly for rewriting now()
        let mut plan = inline_session_timezone_in_exprs(ctx.clone(), self.plan.clone())?;

        // Const eval of exprs at the last minute, but before `to_batch` to make functional index selection happy.
        plan = const_eval_exprs(plan)?;

        if ctx.is_explain_trace() {
            ctx.trace("Const eval exprs:");
            ctx.trace(plan.explain_to_string());
        }

        // Convert to physical plan node
        let mut plan = plan.to_batch_with_order_required(&self.required_order)?;
        if ctx.is_explain_trace() {
            ctx.trace("To Batch Plan:");
            ctx.trace(plan.explain_to_string());
        }

        plan = plan.optimize_by_rules(&OptimizationStage::<Batch>::new(
            "Merge BatchProject",
            vec![BatchProjectMergeRule::create()],
            ApplyOrder::BottomUp,
        ))?;

        // Inline session timezone
        plan = inline_session_timezone_in_exprs(ctx.clone(), plan)?;

        if ctx.is_explain_trace() {
            ctx.trace("Inline Session Timezone:");
            ctx.trace(plan.explain_to_string());
        }

        // Sanity checks: input refs are valid (debug builds only), the plan is a singleton and it
        // contains no exchange.
        #[cfg(debug_assertions)]
        InputRefValidator.validate(plan.clone());
        assert_eq!(
            *plan.distribution(),
            Distribution::Single,
            "{}",
            plan.explain_to_string()
        );
        assert!(
            !has_batch_exchange(plan.clone()),
            "{}",
            plan.explain_to_string()
        );

        let ctx = plan.ctx();
        if ctx.is_explain_trace() {
            ctx.trace("To Batch Physical Plan:");
            ctx.trace(plan.explain_to_string());
        }

        Ok(self.into_phase(plan))
    }

    /// Convert the optimized logical plan into a DataFusion logical plan.
    ///
    /// Applies the same session-timezone inlining and constant folding as `gen_batch_plan`, then
    /// adds a DataFusion `Sort` when the required order is not `any`, and a `Projection` (with
    /// columns aliased to `out_names`) when there are fewer output names than columns in the
    /// DataFusion schema.
    #[cfg(feature = "datafusion")]
    pub fn gen_datafusion_logical_plan(
        &self,
    ) -> Result<Arc<datafusion::logical_expr::LogicalPlan>> {
        use datafusion::logical_expr::{Expr as DFExpr, LogicalPlan, Projection, Sort};
        use datafusion_common::Column;
        use plan_visitor::LogicalPlanToDataFusionExt;

        use crate::datafusion::{InputColumns, convert_column_order};

        tracing::debug!(
            "Converting RisingWave logical plan to DataFusion plan:\nRisingWave Plan: {:?}",
            self.plan
        );

        let ctx = self.plan.ctx();
        // Inline session timezone mainly for rewriting now()
        let mut plan = inline_session_timezone_in_exprs(ctx, self.plan.clone())?;
        plan = const_eval_exprs(plan)?;

        let mut df_plan = plan.to_datafusion_logical_plan()?;

        // Enforce the required order with a DataFusion sort on top of the converted plan.
        if !self.required_order.is_any() {
            let input_columns = InputColumns::new(df_plan.schema().as_ref(), plan.schema());
            let expr = self
                .required_order
                .column_orders
                .iter()
                .map(|column_order| convert_column_order(column_order, &input_columns))
                .collect_vec();
            df_plan = Arc::new(LogicalPlan::Sort(Sort {
                expr,
                input: df_plan,
                fetch: None,
            }));
        }

        // Keep only the columns selected by `out_fields`, renamed to `out_names`.
        if self.out_names.len() < df_plan.schema().fields().len() {
            let df_schema = df_plan.schema().as_ref();
            let projection_exprs = self
                .out_fields
                .ones()
                .zip_eq_debug(self.out_names.iter())
                .map(|(i, name)| {
                    DFExpr::Column(Column::from(df_schema.qualified_field(i))).alias(name)
                })
                .collect_vec();
            df_plan = Arc::new(LogicalPlan::Projection(Projection::try_new(
                projection_exprs,
                df_plan,
            )?));
        }

        tracing::debug!("Converted DataFusion plan:\nDataFusion Plan: {:?}", df_plan);

        Ok(df_plan)
    }
}
450
451impl BatchPlanRoot {
452    /// Optimize and generate a batch query plan for distributed execution.
453    pub fn gen_batch_distributed_plan(mut self) -> Result<BatchPlanRef> {
454        self.required_dist = RequiredDist::single();
455        let mut plan = self.plan;
456
457        // Convert to distributed plan
458        plan = plan.to_distributed_with_required(&self.required_order, &self.required_dist)?;
459
460        let ctx = plan.ctx();
461        if ctx.is_explain_trace() {
462            ctx.trace("To Batch Distributed Plan:");
463            ctx.trace(plan.explain_to_string());
464        }
465        if require_additional_exchange_on_root_in_distributed_mode(plan.clone()) {
466            plan =
467                BatchExchange::new(plan, self.required_order.clone(), Distribution::Single).into();
468        }
469
470        // Add Project if the any position of `self.out_fields` is set to zero.
471        if self.out_fields.count_ones(..) != self.out_fields.len() {
472            plan =
473                BatchProject::new(generic::Project::with_out_fields(plan, &self.out_fields)).into();
474        }
475
476        // Both two phase limit and topn could generate limit on top of the scan, so we push limit here.
477        let plan = plan.optimize_by_rules(&OptimizationStage::new(
478            "Push Limit To Scan",
479            vec![BatchPushLimitToScanRule::create()],
480            ApplyOrder::BottomUp,
481        ))?;
482
483        Ok(plan)
484    }
485
486    /// Optimize and generate a batch query plan for local execution.
487    pub fn gen_batch_local_plan(self) -> Result<BatchPlanRef> {
488        let mut plan = self.plan;
489
490        // Convert to local plan node
491        plan = plan.to_local_with_order_required(&self.required_order)?;
492
493        // We remark that since the `to_local_with_order_required` does not enforce single
494        // distribution, we enforce at the root if needed.
495        let insert_exchange = match plan.distribution() {
496            Distribution::Single => require_additional_exchange_on_root_in_local_mode(plan.clone()),
497            _ => true,
498        };
499        if insert_exchange {
500            plan =
501                BatchExchange::new(plan, self.required_order.clone(), Distribution::Single).into()
502        }
503
504        // Add Project if the any position of `self.out_fields` is set to zero.
505        if self.out_fields.count_ones(..) != self.out_fields.len() {
506            plan =
507                BatchProject::new(generic::Project::with_out_fields(plan, &self.out_fields)).into();
508        }
509
510        let ctx = plan.ctx();
511        if ctx.is_explain_trace() {
512            ctx.trace("To Batch Local Plan:");
513            ctx.trace(plan.explain_to_string());
514        }
515
516        // Both two phase limit and topn could generate limit on top of the scan, so we push limit here.
517        let plan = plan.optimize_by_rules(&OptimizationStage::new(
518            "Push Limit To Scan",
519            vec![BatchPushLimitToScanRule::create()],
520            ApplyOrder::BottomUp,
521        ))?;
522
523        Ok(plan)
524    }
525}
526
527impl LogicalPlanRoot {
528    /// Generate optimized stream plan
529    pub(crate) fn derive_backfill_type(&self, allow_snapshot_backfill: bool) -> BackfillType {
530        if allow_snapshot_backfill && self.should_use_snapshot_backfill() {
531            BackfillType::SnapshotBackfill
532        } else if self.should_use_arrangement_backfill() {
533            BackfillType::ArrangementBackfill
534        } else {
535            BackfillType::Backfill
536        }
537    }
538
539    fn gen_optimized_stream_plan(
540        self,
541        emit_on_window_close: bool,
542        backfill_type: BackfillType,
543    ) -> Result<StreamOptimizedLogicalPlanRoot> {
544        let ctx = self.plan.ctx();
545        let _explain_trace = ctx.is_explain_trace();
546
547        let optimized_plan = self.gen_stream_plan(emit_on_window_close, backfill_type)?;
548
549        let mut plan = optimized_plan
550            .plan
551            .clone()
552            .optimize_by_rules(&OptimizationStage::new(
553                "Merge StreamProject",
554                vec![StreamProjectMergeRule::create()],
555                ApplyOrder::BottomUp,
556            ))?;
557
558        if ctx
559            .session_ctx()
560            .config()
561            .streaming_separate_consecutive_join()
562        {
563            plan = plan.optimize_by_rules(&OptimizationStage::new(
564                "Separate consecutive StreamHashJoin by no-shuffle StreamExchange",
565                vec![SeparateConsecutiveJoinRule::create()],
566                ApplyOrder::BottomUp,
567            ))?;
568        }
569
570        // Add Logstore for Unaligned join
571        // Apply this BEFORE delta join rule, because delta join removes
572        // the join
573        if ctx.session_ctx().config().streaming_enable_unaligned_join() {
574            plan = plan.optimize_by_rules(&OptimizationStage::new(
575                "Add Logstore for Unaligned join",
576                vec![AddLogstoreRule::create()],
577                ApplyOrder::BottomUp,
578            ))?;
579        }
580
581        if ctx.session_ctx().config().streaming_enable_delta_join()
582            && ctx.session_ctx().config().enable_index_selection()
583        {
584            // TODO: make it a logical optimization.
585            // Rewrite joins with index to delta join
586            plan = plan.optimize_by_rules(&OptimizationStage::new(
587                "To IndexDeltaJoin",
588                vec![IndexDeltaJoinRule::create()],
589                ApplyOrder::BottomUp,
590            ))?;
591        }
592        // Inline session timezone
593        plan = inline_session_timezone_in_exprs(ctx.clone(), plan)?;
594
595        if ctx.is_explain_trace() {
596            ctx.trace("Inline session timezone:");
597            ctx.trace(plan.explain_to_string());
598        }
599
600        // Const eval of exprs at the last minute
601        plan = const_eval_exprs(plan)?;
602
603        if ctx.is_explain_trace() {
604            ctx.trace("Const eval exprs:");
605            ctx.trace(plan.explain_to_string());
606        }
607
608        #[cfg(debug_assertions)]
609        InputRefValidator.validate(plan.clone());
610
611        if TemporalJoinValidator::exist_dangling_temporal_scan(plan.clone()) {
612            return Err(ErrorCode::NotSupported(
613                "exist dangling temporal scan".to_owned(),
614                "please check your temporal join syntax e.g. consider removing the right outer join if it is being used.".to_owned(),
615            ).into());
616        }
617
618        if RwTimestampValidator::select_rw_timestamp_in_stream_query(plan.clone()) {
619            return Err(ErrorCode::NotSupported(
620                "selecting `_rw_timestamp` in a streaming query is not allowed".to_owned(),
621                "please run the sql in batch mode or remove the column `_rw_timestamp` from the streaming query".to_owned(),
622            ).into());
623        }
624
625        if LocalityProviderCounter::count(plan.clone()) > 5 {
626            // LocalityProviderCounter is non-zero only when locality backfill is enabled.
627            assert!(ctx.session_ctx().config().enable_locality_backfill());
628            risingwave_common::license::Feature::LocalityBackfill.check_available()?;
629        }
630
631        if ctx.missed_locality_providers() > 1
632            && risingwave_common::license::Feature::LocalityBackfill
633                .check_available()
634                .is_ok()
635        {
636            // missed_locality_providers can only be non-zero when locality backfill is disabled.
637            assert!(!ctx.session_ctx().config().enable_locality_backfill());
638            ctx.warn_to_user(format!(
639                "This streaming job has {} operators that could benefit from locality backfill. \
640                Consider enabling it with `SET enable_locality_backfill = true` for potentially \
641                faster backfill performance, when existing data volume in upstream(s) is large.",
642                ctx.missed_locality_providers()
643            ));
644        }
645
646        Ok(optimized_plan.into_phase(plan))
647    }
648
649    pub(crate) fn require_snapshot_backfill_for_batch_refresh(&self) -> Result<()> {
650        let ctx = self.plan.ctx();
651        let session_ctx = ctx.session_ctx();
652        let snapshot_backfill_enabled = session_ctx
653            .env()
654            .streaming_config()
655            .developer
656            .enable_snapshot_backfill
657            && session_ctx.config().streaming_use_snapshot_backfill();
658        if !snapshot_backfill_enabled {
659            return Err(ErrorCode::NotSupported(
660                "Batch refresh materialized view requires snapshot backfill".to_owned(),
661                format!(
662                    "Please enable snapshot backfill or remove `{}` from the WITH clause.",
663                    MV_REFRESH_INTERVAL_SEC_KEY
664                ),
665            )
666            .into());
667        }
668        if let Some(reason) = self.plan.forbid_snapshot_backfill() {
669            return Err(ErrorCode::NotSupported(
670                format!("Batch refresh materialized view requires snapshot backfill, but {reason}"),
671                "Please rewrite the query to avoid operators that forbid snapshot backfill."
672                    .to_owned(),
673            )
674            .into());
675        }
676        Ok(())
677    }
678
    /// Generate the stream plan of a streaming job (e.g. create index or create materialized
    /// view).
    ///
    /// Steps:
    /// 1. Reject JSONB columns in stream keys unless the session allows them.
    /// 2. Run logical optimization for stream.
    /// 3. Apply `logical_rewrite_for_stream`, making its output column mapping injective.
    /// 4. Rewrite the required distribution, required order and `out_fields` through that
    ///    mapping.
    /// 5. Convert to a stream plan with the required distribution and enforce the
    ///    emit-on-window-close requirement.
    ///
    /// # Errors
    /// Returns `NotSupported` if JSONB is used in a stream key while not allowed, and propagates
    /// errors from the optimization and conversion passes.
    fn gen_stream_plan(
        self,
        emit_on_window_close: bool,
        backfill_type: BackfillType,
    ) -> Result<StreamOptimizedLogicalPlanRoot> {
        let ctx = self.plan.ctx();
        let explain_trace = ctx.is_explain_trace();

        let plan = {
            {
                if !ctx
                    .session_ctx()
                    .config()
                    .streaming_allow_jsonb_in_stream_key()
                    && let Some(err) = StreamKeyChecker.visit(self.plan.clone())
                {
                    return Err(ErrorCode::NotSupported(
                        err,
                        "Using JSONB columns as part of the join or aggregation keys can severely impair performance. \
                        If you intend to proceed, force to enable it with: `set rw_streaming_allow_jsonb_in_stream_key to true`".to_owned(),
                    ).into());
                }
                let mut optimized_plan = self.gen_optimized_logical_plan_for_stream()?;
                let (plan, out_col_change) = {
                    let (plan, out_col_change) = optimized_plan.plan.logical_rewrite_for_stream(
                        &mut RewriteStreamContext::new_with_backfill_type(backfill_type),
                    )?;
                    if out_col_change.is_injective() {
                        (plan, out_col_change)
                    } else {
                        // The rewrite mapped several original columns onto the same new column.
                        // Make the mapping injective: keep the first source that hits a target,
                        // and for every later duplicate append an extra output column (a copy of
                        // that target) via a projection, then redirect the mapping to it.
                        let mut output_indices = (0..plan.schema().len()).collect_vec();
                        #[expect(unused_assignments)]
                        let (mut map, mut target_size) = out_col_change.into_parts();

                        // TODO(st1page): https://github.com/risingwavelabs/risingwave/issues/7234
                        // assert_eq!(target_size, output_indices.len());
                        target_size = plan.schema().len();
                        let mut tar_exists = vec![false; target_size];
                        for i in map.iter_mut().flatten() {
                            if tar_exists[*i] {
                                output_indices.push(*i);
                                *i = target_size;
                                target_size += 1;
                            } else {
                                tar_exists[*i] = true;
                            }
                        }
                        let plan =
                            LogicalProject::with_out_col_idx(plan, output_indices.into_iter());
                        let out_col_change = ColIndexMapping::new(map, target_size);
                        (plan.into(), out_col_change)
                    }
                };
                if explain_trace {
                    ctx.trace("Logical Rewrite For Stream:");
                    ctx.trace(plan.explain_to_string());
                }

                // Carry the required properties over to the rewritten column indices. The
                // `unwrap` panics if the required order cannot be rewritten through the mapping.
                optimized_plan.required_dist =
                    out_col_change.rewrite_required_distribution(&optimized_plan.required_dist);
                optimized_plan.required_order = out_col_change
                    .rewrite_required_order(&optimized_plan.required_order)
                    .unwrap();
                optimized_plan.out_fields =
                    out_col_change.rewrite_bitset(&optimized_plan.out_fields);
                let mut plan = plan.to_stream_with_dist_required(
                    &optimized_plan.required_dist,
                    &mut ToStreamContext::new_with_backfill_type(
                        emit_on_window_close,
                        backfill_type,
                    ),
                )?;
                plan = stream_enforce_eowc_requirement(ctx.clone(), plan, emit_on_window_close)?;
                optimized_plan.into_phase(plan)
            }
        };

        if explain_trace {
            ctx.trace("To Stream Plan:");
            // TODO: can be `plan.plan.explain_to_string()`, but should explicitly specify the type due to some limitation of rust compiler
            ctx.trace(<PlanRef<Stream> as Explain>::explain_to_string(&plan.plan));
        }
        Ok(plan)
    }
764
765    /// Visit the plan root and compute the cardinality.
766    ///
767    /// Panics if not called on a logical plan.
768    fn compute_cardinality(&self) -> Cardinality {
769        CardinalityVisitor.visit(self.plan.clone())
770    }
771
    /// Optimize and generate a create table plan.
    ///
    /// The resulting `StreamMaterialize` is fed by a `StreamUnion` of:
    /// - the external source path, only when `source_catalog` is set;
    /// - the DML path (`StreamDml`, a projection for generated columns, then a distribution
    ///   enforcer chosen by the primary-key kind);
    /// - a `StreamUpstreamSinkUnion` input.
    ///
    /// On top of the union, it adds (when applicable) row-id generation, a watermark filter,
    /// session-timezone inlining and a filter dropping rows with NULLs in non-nullable columns.
    ///
    /// # Errors
    /// Returns an error when:
    /// - stream optimization fails;
    /// - a version column is missing or has a non-comparable type;
    /// - version columns are combined with the IGNORE conflict behavior;
    /// - the conflict behavior cannot be derived;
    /// - a refreshable table has no user-defined primary key.
    ///
    /// # Panics
    /// Panics in these cases:
    /// - neither `pk_column_ids` nor `row_id_index` is provided;
    /// - a pk column id is not among `columns`;
    /// - a row-id table's pk is not exactly the row-id column;
    /// - the union inputs end up with an unexpected combination of distributions.
    pub fn gen_table_plan(
        self,
        context: OptimizerContextRef,
        table_name: String,
        database_id: DatabaseId,
        schema_id: SchemaId,
        CreateTableInfo {
            columns,
            pk_column_ids,
            row_id_index,
            watermark_descs,
            source_catalog,
            version,
        }: CreateTableInfo,
        CreateTableProps {
            definition,
            append_only,
            on_conflict,
            with_version_columns,
            webhook_info,
            engine,
        }: CreateTableProps,
    ) -> Result<StreamMaterialize> {
        let backfill_type = self.derive_backfill_type(false);
        // Snapshot backfill is not allowed for create table
        let stream_plan = self.gen_optimized_stream_plan(false, backfill_type)?;

        // A table is keyed either by user-defined pk columns or by the row-id column.
        assert!(!pk_column_ids.is_empty() || row_id_index.is_some());

        // Positions of the pk columns within `columns`, in `pk_column_ids` order.
        let pk_column_indices = {
            let mut id_to_idx = HashMap::new();

            columns.iter().enumerate().for_each(|(idx, c)| {
                id_to_idx.insert(c.column_id(), idx);
            });
            pk_column_ids
                .iter()
                .map(|c| id_to_idx.get(c).copied().unwrap()) // pk column id must exist in table columns.
                .collect_vec()
        };

        /// Wraps `node` in a `StreamProject` when
        /// `LogicalSource::derive_output_exprs_from_generated_columns` yields expressions
        /// (presumably: when `columns` contains generated columns); otherwise returns `node`.
        fn inject_project_for_generated_column_if_needed(
            columns: &[ColumnCatalog],
            node: StreamPlanRef,
        ) -> Result<StreamPlanRef> {
            let exprs = LogicalSource::derive_output_exprs_from_generated_columns(columns)?;
            if let Some(exprs) = exprs {
                let logical_project = generic::Project::new(exprs, node);
                return Ok(StreamProject::new(logical_project).into());
            }
            Ok(node)
        }

        /// How the table's primary key is formed. This selects the distribution enforced on the
        /// DML path and on the external source path.
        #[derive(PartialEq, Debug, Copy, Clone)]
        enum PrimaryKeyKind {
            /// Primary key declared by the user (no row-id column).
            UserDefinedPrimaryKey,
            /// Row-id primary key on a table that is not append-only.
            NonAppendOnlyRowIdPk,
            /// Row-id primary key on an append-only table.
            AppendOnlyRowIdPk,
        }

        /// Builds the DML path:
        /// 1. `StreamDml` over `stream_plan`;
        /// 2. the generated-column projection, if any;
        /// 3. a distribution enforcer — hash sharding on `pk_column_indices`, except for
        ///    append-only row-id tables, which get a no-shuffle exchange.
        fn inject_dml_node(
            columns: &[ColumnCatalog],
            append_only: bool,
            stream_plan: StreamPlanRef,
            pk_column_indices: &[usize],
            kind: PrimaryKeyKind,
            column_descs: Vec<ColumnDesc>,
        ) -> Result<StreamPlanRef> {
            let mut dml_node = StreamDml::new(stream_plan, append_only, column_descs).into();

            // Add generated columns.
            dml_node = inject_project_for_generated_column_if_needed(columns, dml_node)?;

            dml_node = match kind {
                PrimaryKeyKind::UserDefinedPrimaryKey | PrimaryKeyKind::NonAppendOnlyRowIdPk => {
                    RequiredDist::hash_shard(pk_column_indices)
                        .streaming_enforce_if_not_satisfies(dml_node)?
                }
                PrimaryKeyKind::AppendOnlyRowIdPk => {
                    StreamExchange::new_no_shuffle(dml_node).into()
                }
            };

            Ok(dml_node)
        }

        // Classify the primary key. When a row-id column exists it must be the sole pk column.
        let kind = if let Some(row_id_index) = row_id_index {
            assert_eq!(
                Itertools::exactly_one(pk_column_indices.iter())
                    .copied()
                    .unwrap(),
                row_id_index
            );
            if append_only {
                PrimaryKeyKind::AppendOnlyRowIdPk
            } else {
                PrimaryKeyKind::NonAppendOnlyRowIdPk
            }
        } else {
            PrimaryKeyKind::UserDefinedPrimaryKey
        };

        // Descriptors of the columns accepted by DML (those for which `can_dml()` holds).
        let column_descs: Vec<ColumnDesc> = columns
            .iter()
            .filter(|&c| c.can_dml())
            .map(|c| c.column_desc.clone())
            .collect();

        // Positions (within `column_descs`) of non-nullable columns, used by the not-null filter
        // below.
        // NOTE(review): these index the DML-able subset of `columns`, while the filter is applied
        // to the table's full output schema. Confirm the two coincide when some columns are not
        // DML-able.
        let mut not_null_idxs = vec![];
        for (idx, column) in column_descs.iter().enumerate() {
            if !column.nullable {
                not_null_idxs.push(idx);
            }
        }

        // Resolve the requested version column names to column indices.
        let version_column_indices = if !with_version_columns.is_empty() {
            find_version_column_indices(&columns, with_version_columns)?
        } else {
            vec![]
        };

        // Without an external source, the optimized plan itself feeds the DML node.
        // With one, the optimized plan becomes the external source path, and the DML path starts
        // from a dummy `LogicalSource` built without a source catalog. The external path is
        // hash-sharded on the user-defined pk, or passed through a no-shuffle exchange for
        // row-id pks.
        let with_external_source = source_catalog.is_some();
        let (dml_source_node, external_source_node) = if with_external_source {
            let dummy_source_node = LogicalSource::new(
                None,
                columns.clone(),
                row_id_index,
                SourceNodeKind::CreateTable,
                context.clone(),
                None,
            )
            .and_then(|s| s.to_stream(&mut ToStreamContext::new(false)))?;
            let mut external_source_node = stream_plan.plan;
            external_source_node =
                inject_project_for_generated_column_if_needed(&columns, external_source_node)?;
            external_source_node = match kind {
                PrimaryKeyKind::UserDefinedPrimaryKey => {
                    RequiredDist::hash_shard(&pk_column_indices)
                        .streaming_enforce_if_not_satisfies(external_source_node)?
                }

                PrimaryKeyKind::NonAppendOnlyRowIdPk | PrimaryKeyKind::AppendOnlyRowIdPk => {
                    StreamExchange::new_no_shuffle(external_source_node).into()
                }
            };
            (dummy_source_node, Some(external_source_node))
        } else {
            (stream_plan.plan, None)
        };

        let dml_node = inject_dml_node(
            &columns,
            append_only,
            dml_source_node,
            &pk_column_indices,
            kind,
            column_descs,
        )?;

        // Distinct distributions among the external source path (if any) and the DML path.
        let dists = external_source_node
            .iter()
            .map(|input| input.distribution())
            .chain([dml_node.distribution()])
            .unique()
            .collect_vec();

        // Pick the union's distribution:
        // - a SomeShard/HashShard mix degrades to SomeShard;
        // - a single distinct distribution is kept;
        // - any other combination (e.g. two different hash distributions) is treated as
        //   unreachable.
        let dist = match &dists[..] {
            &[Distribution::SomeShard, Distribution::HashShard(_)]
            | &[Distribution::HashShard(_), Distribution::SomeShard] => Distribution::SomeShard,
            &[dist @ Distribution::SomeShard] | &[dist @ Distribution::HashShard(_)] => {
                dist.clone()
            }
            _ => {
                unreachable!()
            }
        };

        // Extra union input that shares the DML node's schema and stream key. Presumably this is
        // where sinks targeting this table (sink-into-table) attach — confirm against
        // `StreamUpstreamSinkUnion`.
        let generated_column_exprs =
            LogicalSource::derive_output_exprs_from_generated_columns(&columns)?;
        let upstream_sink_union = StreamUpstreamSinkUnion::new(
            context.clone(),
            dml_node.schema(),
            dml_node.stream_key(),
            dist.clone(), // should always be the same as dist of `Union`
            append_only,
            row_id_index.is_none(),
            generated_column_exprs,
        );

        // Union order: external source path (if any), DML path, upstream sink union.
        let union_inputs = external_source_node
            .into_iter()
            .chain([dml_node, upstream_sink_union.into()])
            .collect_vec();

        let mut stream_plan: StreamPlanRef = StreamUnion::new_with_dist(
            Union {
                all: true,
                inputs: union_inputs,
                source_col: None,
            },
            dist,
        )
        .into();

        // Indices of the watermark columns declared with TTL.
        let ttl_watermark_indices = watermark_descs
            .iter()
            .filter(|d| d.with_ttl)
            .map(|d| d.watermark_idx as usize)
            .collect_vec();

        // Only invoked when `row_id_index` is `Some`, in which case `kind` is never
        // `UserDefinedPrimaryKey` (see the classification above).
        let add_row_id_gen = |stream_plan: StreamPlanRef, row_id_index| match kind {
            PrimaryKeyKind::UserDefinedPrimaryKey => {
                unreachable!()
            }
            PrimaryKeyKind::NonAppendOnlyRowIdPk | PrimaryKeyKind::AppendOnlyRowIdPk => {
                StreamRowIdGen::new_with_dist(
                    stream_plan,
                    row_id_index,
                    Distribution::HashShard(vec![row_id_index]),
                )
                .into()
            }
        };

        // Add RowIDGen before WatermarkFilter, so filtering always sees a valid row-id key.
        if let Some(row_id_index) = row_id_index {
            stream_plan = add_row_id_gen(stream_plan, row_id_index);
        }

        // Add WatermarkFilter node.
        if !watermark_descs.is_empty() {
            stream_plan = StreamWatermarkFilter::new(stream_plan, watermark_descs).into();
        }

        let conflict_behavior = on_conflict.to_behavior(append_only, row_id_index.is_some())?;

        // Reject combining version columns with the IGNORE conflict behavior.
        if let ConflictBehavior::IgnoreConflict = conflict_behavior
            && !version_column_indices.is_empty()
        {
            Err(ErrorCode::InvalidParameterValue(
                "The with version column syntax cannot be used with the ignore behavior of on conflict".to_owned(),
            ))?
        }

        let retention_seconds = context.with_options().retention_seconds();

        // The materialized table must be sharded by its primary key columns.
        let table_required_dist = {
            let mut bitset = FixedBitSet::with_capacity(columns.len());
            for idx in &pk_column_indices {
                bitset.insert(*idx);
            }
            RequiredDist::ShardByKey(bitset)
        };

        let mut stream_plan = inline_session_timezone_in_exprs(context, stream_plan)?;

        // Drop rows carrying NULL in any non-nullable column.
        if !not_null_idxs.is_empty() {
            stream_plan =
                StreamFilter::filter_out_any_null_rows(stream_plan.clone(), &not_null_idxs);
        }

        // Determine if the table should be refreshable based on the connector type
        let refreshable = source_catalog
            .as_ref()
            .map(|catalog| {
                catalog.with_properties.is_batch_connector() || {
                    matches!(
                        catalog
                            .refresh_mode
                            .as_ref()
                            .map(|refresh_mode| refresh_mode.refresh_mode),
                        Some(Some(RefreshMode::FullReload(_)))
                    )
                }
            })
            .unwrap_or(false);

        // Validate that refreshable tables have a user-defined primary key (i.e., does not have rowid)
        if refreshable && row_id_index.is_some() {
            return Err(crate::error::ErrorCode::BindError(
                "Refreshable tables must have a PRIMARY KEY. Please define a primary key for the table."
                    .to_owned(),
            )
            .into());
        }

        StreamMaterialize::create_for_table(
            stream_plan,
            table_name,
            database_id,
            schema_id,
            table_required_dist,
            Order::any(),
            columns,
            definition,
            conflict_behavior,
            version_column_indices,
            pk_column_indices,
            ttl_watermark_indices,
            row_id_index,
            version,
            retention_seconds,
            webhook_info,
            engine,
            refreshable,
        )
    }
1080
1081    /// Optimize and generate a create materialized view plan.
1082    pub fn gen_materialize_plan(
1083        self,
1084        database_id: DatabaseId,
1085        schema_id: SchemaId,
1086        mv_name: String,
1087        definition: String,
1088        emit_on_window_close: bool,
1089        backfill_type: BackfillType,
1090    ) -> Result<StreamMaterialize> {
1091        let cardinality = self.compute_cardinality();
1092        let stream_plan = self.gen_optimized_stream_plan(emit_on_window_close, backfill_type)?;
1093        StreamMaterialize::create(
1094            stream_plan,
1095            mv_name,
1096            database_id,
1097            schema_id,
1098            definition,
1099            TableType::MaterializedView,
1100            cardinality,
1101            None,
1102        )
1103    }
1104
1105    /// Optimize and generate a create index plan.
1106    pub fn gen_index_plan(
1107        self,
1108        index_name: String,
1109        database_id: DatabaseId,
1110        schema_id: SchemaId,
1111        definition: String,
1112        retention_seconds: Option<NonZeroU32>,
1113    ) -> Result<StreamMaterialize> {
1114        let cardinality = self.compute_cardinality();
1115        let backfill_type = self.derive_backfill_type(false);
1116        let stream_plan = self.gen_optimized_stream_plan(false, backfill_type)?;
1117
1118        StreamMaterialize::create(
1119            stream_plan,
1120            index_name,
1121            database_id,
1122            schema_id,
1123            definition,
1124            TableType::Index,
1125            cardinality,
1126            retention_seconds,
1127        )
1128    }
1129
1130    pub fn gen_vector_index_plan(
1131        self,
1132        index_name: String,
1133        database_id: DatabaseId,
1134        schema_id: SchemaId,
1135        definition: String,
1136        retention_seconds: Option<NonZeroU32>,
1137        vector_index_info: PbVectorIndexInfo,
1138    ) -> Result<StreamVectorIndexWrite> {
1139        let cardinality = self.compute_cardinality();
1140        let backfill_type = self.derive_backfill_type(false);
1141        let stream_plan = self.gen_optimized_stream_plan(false, backfill_type)?;
1142
1143        StreamVectorIndexWrite::create(
1144            stream_plan,
1145            index_name,
1146            database_id,
1147            schema_id,
1148            definition,
1149            cardinality,
1150            retention_seconds,
1151            vector_index_info,
1152        )
1153    }
1154
1155    /// Optimize and generate a create sink plan.
1156    #[expect(clippy::too_many_arguments)]
1157    pub fn gen_sink_plan(
1158        self,
1159        sink_name: String,
1160        definition: String,
1161        properties: WithOptionsSecResolved,
1162        emit_on_window_close: bool,
1163        db_name: String,
1164        sink_from_table_name: String,
1165        format_desc: Option<SinkFormatDesc>,
1166        without_backfill: bool,
1167        target_table: Option<Arc<TableCatalog>>,
1168        partition_info: Option<PartitionComputeInfo>,
1169        user_specified_columns: bool,
1170        auto_refresh_schema_from_table: Option<Arc<TableCatalog>>,
1171        allow_snapshot_backfill: bool,
1172    ) -> Result<StreamSink> {
1173        let backfill_type = if without_backfill {
1174            BackfillType::UpstreamOnly
1175        } else if allow_snapshot_backfill
1176            && self.should_use_snapshot_backfill()
1177            && {
1178                if auto_refresh_schema_from_table.is_some() {
1179                    self.plan.ctx().session_ctx().notice_to_user("Auto schema change only support for ArrangementBackfill. Switched to use ArrangementBackfill");
1180                    false
1181                } else {
1182                    true
1183                }
1184            }
1185        {
1186            assert!(
1187                target_table.is_none(),
1188                "should not allow snapshot backfill for sink-into-table"
1189            );
1190            // Snapshot backfill on sink-into-table is not allowed
1191            BackfillType::SnapshotBackfill
1192        } else if self.should_use_arrangement_backfill() {
1193            BackfillType::ArrangementBackfill
1194        } else {
1195            BackfillType::Backfill
1196        };
1197        if auto_refresh_schema_from_table.is_some()
1198            && backfill_type != BackfillType::ArrangementBackfill
1199        {
1200            return Err(ErrorCode::InvalidInputSyntax(format!(
1201                "auto schema change only support for ArrangementBackfill, but got: {:?}",
1202                backfill_type
1203            ))
1204            .into());
1205        }
1206        let stream_plan = self.gen_optimized_stream_plan(emit_on_window_close, backfill_type)?;
1207        let target_columns_to_plan_mapping = target_table.as_ref().map(|t| {
1208            let columns = t.columns_without_rw_timestamp();
1209            stream_plan.target_columns_to_plan_mapping(&columns, user_specified_columns)
1210        });
1211
1212        StreamSink::create(
1213            stream_plan,
1214            sink_name,
1215            db_name,
1216            sink_from_table_name,
1217            target_table,
1218            target_columns_to_plan_mapping,
1219            definition,
1220            properties,
1221            format_desc,
1222            partition_info,
1223            auto_refresh_schema_from_table,
1224        )
1225    }
1226
1227    pub fn should_use_arrangement_backfill(&self) -> bool {
1228        let ctx = self.plan.ctx();
1229        let session_ctx = ctx.session_ctx();
1230        let arrangement_backfill_enabled = session_ctx
1231            .env()
1232            .streaming_config()
1233            .developer
1234            .enable_arrangement_backfill;
1235        arrangement_backfill_enabled && session_ctx.config().streaming_use_arrangement_backfill()
1236    }
1237
1238    pub fn should_use_snapshot_backfill(&self) -> bool {
1239        let ctx = self.plan.ctx();
1240        let session_ctx = ctx.session_ctx();
1241        let use_snapshot_backfill = session_ctx
1242            .env()
1243            .streaming_config()
1244            .developer
1245            .enable_snapshot_backfill
1246            && session_ctx.config().streaming_use_snapshot_backfill();
1247        if use_snapshot_backfill {
1248            if let Some(warning_msg) = self.plan.forbid_snapshot_backfill() {
1249                self.plan.ctx().session_ctx().notice_to_user(warning_msg);
1250                false
1251            } else {
1252                true
1253            }
1254        } else {
1255            false
1256        }
1257    }
1258}
1259
1260impl<P: PlanPhase> PlanRoot<P> {
1261    /// used when the plan has a target relation such as DML and sink into table, return the mapping from table's columns to the plan's schema
1262    pub fn target_columns_to_plan_mapping(
1263        &self,
1264        tar_cols: &[ColumnCatalog],
1265        user_specified_columns: bool,
1266    ) -> Vec<Option<usize>> {
1267        #[expect(clippy::disallowed_methods)]
1268        let visible_cols: Vec<(usize, String)> = self
1269            .out_fields
1270            .ones()
1271            .zip_eq(self.out_names.iter().cloned())
1272            .collect_vec();
1273
1274        let visible_col_idxes = visible_cols.iter().map(|(i, _)| *i).collect_vec();
1275        let visible_col_idxes_by_name = visible_cols
1276            .iter()
1277            .map(|(i, name)| (name.as_ref(), *i))
1278            .collect::<BTreeMap<_, _>>();
1279
1280        tar_cols
1281            .iter()
1282            .enumerate()
1283            .filter(|(_, tar_col)| tar_col.can_dml())
1284            .map(|(tar_i, tar_col)| {
1285                if user_specified_columns {
1286                    visible_col_idxes_by_name.get(tar_col.name()).cloned()
1287                } else {
1288                    (tar_i < visible_col_idxes.len()).then(|| visible_cols[tar_i].0)
1289                }
1290            })
1291            .collect()
1292    }
1293}
1294
1295fn find_version_column_indices(
1296    column_catalog: &Vec<ColumnCatalog>,
1297    version_column_names: Vec<String>,
1298) -> Result<Vec<usize>> {
1299    let mut indices = Vec::new();
1300    for version_column_name in version_column_names {
1301        let mut found = false;
1302        for (index, column) in column_catalog.iter().enumerate() {
1303            if column.column_desc.name == version_column_name {
1304                if let &DataType::Jsonb
1305                | &DataType::List(_)
1306                | &DataType::Struct(_)
1307                | &DataType::Bytea
1308                | &DataType::Boolean = column.data_type()
1309                {
1310                    return Err(ErrorCode::InvalidInputSyntax(format!(
1311                        "Version column {} must be of a comparable data type",
1312                        version_column_name
1313                    ))
1314                    .into());
1315                }
1316                indices.push(index);
1317                found = true;
1318                break;
1319            }
1320        }
1321        if !found {
1322            return Err(ErrorCode::InvalidInputSyntax(format!(
1323                "Version column {} not found",
1324                version_column_name
1325            ))
1326            .into());
1327        }
1328    }
1329    Ok(indices)
1330}
1331
1332fn const_eval_exprs<C: ConventionMarker>(plan: PlanRef<C>) -> Result<PlanRef<C>> {
1333    let mut const_eval_rewriter = ConstEvalRewriter { error: None };
1334
1335    let plan = plan.rewrite_exprs_recursive(&mut const_eval_rewriter);
1336    if let Some(error) = const_eval_rewriter.error {
1337        return Err(error);
1338    }
1339    Ok(plan)
1340}
1341
1342fn inline_session_timezone_in_exprs<C: ConventionMarker>(
1343    ctx: OptimizerContextRef,
1344    plan: PlanRef<C>,
1345) -> Result<PlanRef<C>> {
1346    let mut v = TimestamptzExprFinder::default();
1347    plan.visit_exprs_recursive(&mut v);
1348    if v.has() {
1349        Ok(plan.rewrite_exprs_recursive(ctx.session_timezone().deref_mut()))
1350    } else {
1351        Ok(plan)
1352    }
1353}
1354
1355fn exist_and_no_exchange_before(
1356    plan: &BatchPlanRef,
1357    is_candidate: fn(&BatchPlanRef) -> bool,
1358) -> bool {
1359    if plan.node_type() == BatchPlanNodeType::BatchExchange {
1360        return false;
1361    }
1362    is_candidate(plan)
1363        || plan
1364            .inputs()
1365            .iter()
1366            .any(|input| exist_and_no_exchange_before(input, is_candidate))
1367}
1368
1369impl BatchPlanRef {
1370    fn is_user_table_scan(&self) -> bool {
1371        self.node_type() == BatchPlanNodeType::BatchSeqScan
1372            || self.node_type() == BatchPlanNodeType::BatchLogSeqScan
1373            || self.node_type() == BatchPlanNodeType::BatchVectorSearch
1374    }
1375
1376    fn is_source_scan(&self) -> bool {
1377        self.node_type() == BatchPlanNodeType::BatchSource
1378            || self.node_type() == BatchPlanNodeType::BatchKafkaScan
1379            || self.node_type() == BatchPlanNodeType::BatchIcebergScan
1380    }
1381
1382    fn is_insert(&self) -> bool {
1383        self.node_type() == BatchPlanNodeType::BatchInsert
1384    }
1385
1386    fn is_update(&self) -> bool {
1387        self.node_type() == BatchPlanNodeType::BatchUpdate
1388    }
1389
1390    fn is_delete(&self) -> bool {
1391        self.node_type() == BatchPlanNodeType::BatchDelete
1392    }
1393}
1394
1395/// As we always run the root stage locally, for some plan in root stage which need to execute in
1396/// compute node we insert an additional exhchange before it to avoid to include it in the root
1397/// stage.
1398///
1399/// Returns `true` if we must insert an additional exchange to ensure this.
1400fn require_additional_exchange_on_root_in_distributed_mode(plan: BatchPlanRef) -> bool {
1401    assert_eq!(plan.distribution(), &Distribution::Single);
1402    exist_and_no_exchange_before(&plan, |plan| {
1403        plan.is_user_table_scan()
1404            || plan.is_source_scan()
1405            || plan.is_insert()
1406            || plan.is_update()
1407            || plan.is_delete()
1408    })
1409}
1410
1411/// The purpose is same as `require_additional_exchange_on_root_in_distributed_mode`. We separate
1412/// them for the different requirement of plan node in different execute mode.
1413fn require_additional_exchange_on_root_in_local_mode(plan: BatchPlanRef) -> bool {
1414    assert_eq!(plan.distribution(), &Distribution::Single);
1415    exist_and_no_exchange_before(&plan, |plan| {
1416        plan.is_user_table_scan() || plan.is_source_scan() || plan.is_insert()
1417    })
1418}
1419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizer::plan_node::LogicalValues;

    /// Selecting only `v1` from a two-column `Values` through `out_fields` should yield an
    /// unordered subplan whose schema contains just `v1`.
    #[tokio::test]
    async fn test_as_subplan() {
        let ctx = OptimizerContext::mock();
        let input_schema = Schema::new(vec![
            Field::with_name(DataType::Int32, "v1"),
            Field::with_name(DataType::Varchar, "v2"),
        ]);
        let values = LogicalValues::new(vec![], input_schema, ctx).into();

        // Block value 1 sets only bit 0, i.e. only `v1` is an output column.
        let out_fields = FixedBitSet::with_capacity_and_blocks(2, [1]);
        let out_names = vec!["v1".into()];

        let subplan = PlanRoot::new_with_logical_plan(
            values,
            RequiredDist::Any,
            Order::any(),
            out_fields,
            out_names,
        )
        .into_unordered_subplan();

        let expected = Schema::new(vec![Field::with_name(DataType::Int32, "v1")]);
        assert_eq!(subplan.schema(), &expected);
    }
}