risingwave_frontend/planner/
update.rs1use fixedbitset::FixedBitSet;
16use risingwave_common::types::{DataType, Scalar};
17use risingwave_pb::expr::expr_node::Type;
18
19use super::Planner;
20use crate::binder::{BoundUpdate, UpdateProject};
21use crate::error::Result;
22use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, Literal};
23use crate::optimizer::plan_node::generic::GenericPlanRef;
24use crate::optimizer::plan_node::{
25 LogicalPlanRef as PlanRef, LogicalProject, LogicalUpdate, generic,
26};
27use crate::optimizer::property::{Order, RequiredDist};
28use crate::optimizer::{LogicalPlanRoot, PlanRoot};
29
30impl Planner {
31 pub(super) fn plan_update(&mut self, update: BoundUpdate) -> Result<LogicalPlanRoot> {
32 let returning = !update.returning_list.is_empty();
33
34 let scan = self.plan_base_table(&update.table)?;
35 let input = if let Some(expr) = update.selection {
36 self.plan_where(scan, expr)?
37 } else {
38 scan
39 };
40 let old_schema_len = input.schema().len();
41
42 let with_new: PlanRef = {
44 let mut plan = input;
45
46 let mut exprs: Vec<ExprImpl> = plan
47 .schema()
48 .data_types()
49 .into_iter()
50 .enumerate()
51 .map(|(index, data_type)| InputRef::new(index, data_type).into())
52 .collect();
53
54 exprs.extend(update.exprs);
55
56 if exprs.iter().any(|e| e.has_subquery()) {
58 (plan, exprs) = self.substitute_subqueries_in_cross_join_way(plan, exprs)?;
59 }
60
61 LogicalProject::new(plan, exprs).into()
62 };
63
64 let mut olds = Vec::new();
65 let mut news = Vec::new();
66
67 for (i, col) in update.table.table_catalog.columns().iter().enumerate() {
68 if !col.can_dml() {
70 continue;
71 }
72 let data_type = col.data_type();
73
74 let old: ExprImpl = InputRef::new(i, data_type.clone()).into();
75
76 let mut new: ExprImpl = match (update.projects.get(&i)).map(|p| p.offset(old_schema_len)) {
77 Some(UpdateProject::Simple(j)) => InputRef::new(j, data_type.clone()).into(),
78 Some(UpdateProject::Composite(j, field)) => FunctionCall::new_unchecked(
79 Type::Field,
80 vec![
81 InputRef::new(j, with_new.schema().data_types()[j].clone()).into(), Literal::new(Some((field as i32).to_scalar_value()), DataType::Int32)
83 .into(),
84 ],
85 data_type.clone(),
86 )
87 .into(),
88
89 None => old.clone(),
90 };
91 if !col.nullable() {
92 new = FunctionCall::new_unchecked(
93 ExprType::CheckNotNull,
94 vec![
95 new,
96 ExprImpl::literal_varchar(col.name().to_owned()),
97 ExprImpl::literal_varchar(update.table_name.clone()),
98 ],
99 data_type.clone(),
100 )
101 .into();
102 }
103
104 olds.push(old);
105 news.push(new);
106 }
107
108 let mut plan: PlanRef = LogicalUpdate::from(generic::Update::new(
109 with_new,
110 update.table_name.clone(),
111 update.table_id,
112 update.table_version_id,
113 olds,
114 news,
115 returning,
116 ))
117 .into();
118
119 if returning {
120 plan = LogicalProject::create(plan, update.returning_list);
121 }
122
123 let dist = RequiredDist::Any;
125 let mut out_fields = FixedBitSet::with_capacity(plan.schema().len());
126 out_fields.insert_range(..);
127 let out_names = if returning {
128 update.returning_schema.expect("If returning list is not empty, should provide returning schema in BoundDelete.").names()
129 } else {
130 plan.schema().names()
131 };
132
133 let root = PlanRoot::new_with_logical_plan(plan, dist, Order::any(), out_fields, out_names);
134 Ok(root)
135 }
136}