risingwave_frontend/optimizer/
optimizer_context.rs1use core::fmt::Formatter;
16use std::cell::{Cell, RefCell, RefMut};
17use std::collections::HashMap;
18use std::marker::PhantomData;
19use std::rc::Rc;
20use std::sync::Arc;
21
22use risingwave_sqlparser::ast::{ExplainFormat, ExplainOptions, ExplainType};
23
24use super::property::WatermarkGroupId;
25use crate::Explain;
26use crate::binder::ShareId;
27use crate::expr::{CorrelatedId, SessionTimezone};
28use crate::handler::HandlerArgs;
29use crate::optimizer::LogicalPlanRef;
30use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, PlanNodeId};
31use crate::session::SessionImpl;
32use crate::utils::{OverwriteOptions, WithOptions};
33
/// Starting value for the plan-node, expression-display and watermark-group
/// id counters of a context built with `OptimizerContext::new`; the first id
/// handed out by such a context is therefore `RESERVED_ID_NUM + 1`.
///
/// NOTE(review): presumably ids at or below this value are set aside for
/// other uses — confirm against the callers.
const RESERVED_ID_NUM: u16 = 10_000;

/// Zero-sized marker: `Rc<()>` is neither `Send` nor `Sync`, so embedding this
/// in a struct makes that struct neither `Send` nor `Sync` as well.
type PhantomUnsend = PhantomData<Rc<()>>;
37
38pub struct OptimizerContext {
39 session_ctx: Arc<SessionImpl>,
40 sql: Arc<str>,
42 normalized_sql: String,
44 explain_options: ExplainOptions,
46 optimizer_trace: RefCell<Vec<String>>,
48 logical_explain: RefCell<Option<String>>,
50 with_options: WithOptions,
52 session_timezone: RefCell<SessionTimezone>,
54 total_rule_applied: RefCell<usize>,
56 overwrite_options: OverwriteOptions,
59 rcte_cache: RefCell<HashMap<ShareId, PlanRef>>,
62
63 last_plan_node_id: Cell<i32>,
65 last_correlated_id: Cell<u32>,
67 last_expr_display_id: Cell<usize>,
69 last_watermark_group_id: Cell<u32>,
71
72 _phantom: PhantomUnsend,
73}
74
75pub(in crate::optimizer) struct LastAssignedIds {
76 last_plan_node_id: i32,
77 last_correlated_id: u32,
78 last_expr_display_id: usize,
79 last_watermark_group_id: u32,
80}
81
82pub type OptimizerContextRef = Rc<OptimizerContext>;
83
84impl OptimizerContext {
85 pub fn from_handler_args(handler_args: HandlerArgs) -> Self {
88 Self::new(handler_args, ExplainOptions::default())
89 }
90
91 pub fn new(mut handler_args: HandlerArgs, explain_options: ExplainOptions) -> Self {
93 let session_timezone = RefCell::new(SessionTimezone::new(
94 handler_args.session.config().timezone().to_owned(),
95 ));
96 let overwrite_options = OverwriteOptions::new(&mut handler_args);
97 Self {
98 session_ctx: handler_args.session,
99 sql: handler_args.sql,
100 normalized_sql: handler_args.normalized_sql,
101 explain_options,
102 optimizer_trace: RefCell::new(vec![]),
103 logical_explain: RefCell::new(None),
104 with_options: handler_args.with_options,
105 session_timezone,
106 total_rule_applied: RefCell::new(0),
107 overwrite_options,
108 rcte_cache: RefCell::new(HashMap::new()),
109
110 last_plan_node_id: Cell::new(RESERVED_ID_NUM.into()),
111 last_correlated_id: Cell::new(0),
112 last_expr_display_id: Cell::new(RESERVED_ID_NUM.into()),
113 last_watermark_group_id: Cell::new(RESERVED_ID_NUM.into()),
114
115 _phantom: Default::default(),
116 }
117 }
118
119 #[cfg(test)]
121 #[expect(clippy::unused_async)]
122 pub async fn mock() -> OptimizerContextRef {
123 Self {
124 session_ctx: Arc::new(SessionImpl::mock()),
125 sql: Arc::from(""),
126 normalized_sql: "".to_owned(),
127 explain_options: ExplainOptions::default(),
128 optimizer_trace: RefCell::new(vec![]),
129 logical_explain: RefCell::new(None),
130 with_options: Default::default(),
131 session_timezone: RefCell::new(SessionTimezone::new("UTC".into())),
132 total_rule_applied: RefCell::new(0),
133 overwrite_options: OverwriteOptions::default(),
134 rcte_cache: RefCell::new(HashMap::new()),
135
136 last_plan_node_id: Cell::new(0),
137 last_correlated_id: Cell::new(0),
138 last_expr_display_id: Cell::new(0),
139 last_watermark_group_id: Cell::new(0),
140
141 _phantom: Default::default(),
142 }
143 .into()
144 }
145
146 pub fn next_plan_node_id(&self) -> PlanNodeId {
147 self.last_plan_node_id.update(|id| id + 1);
148 PlanNodeId(self.last_plan_node_id.get())
149 }
150
151 pub fn next_correlated_id(&self) -> CorrelatedId {
152 self.last_correlated_id.update(|id| id + 1);
153 self.last_correlated_id.get()
154 }
155
156 pub fn next_expr_display_id(&self) -> usize {
157 self.last_expr_display_id.update(|id| id + 1);
158 self.last_expr_display_id.get()
159 }
160
161 pub fn next_watermark_group_id(&self) -> WatermarkGroupId {
162 self.last_watermark_group_id.update(|id| id + 1);
163 self.last_watermark_group_id.get()
164 }
165
166 pub(in crate::optimizer) fn backup_elem_ids(&self) -> LastAssignedIds {
167 LastAssignedIds {
168 last_plan_node_id: self.last_plan_node_id.get(),
169 last_correlated_id: self.last_correlated_id.get(),
170 last_expr_display_id: self.last_expr_display_id.get(),
171 last_watermark_group_id: self.last_watermark_group_id.get(),
172 }
173 }
174
175 pub(in crate::optimizer) fn reset_elem_ids(&self) {
177 self.last_plan_node_id.set(0);
178 self.last_correlated_id.set(0);
179 self.last_expr_display_id.set(0);
180 self.last_watermark_group_id.set(0);
181 }
182
183 pub(in crate::optimizer) fn restore_elem_ids(&self, backup: LastAssignedIds) {
184 self.last_plan_node_id.set(backup.last_plan_node_id);
185 self.last_correlated_id.set(backup.last_correlated_id);
186 self.last_expr_display_id.set(backup.last_expr_display_id);
187 self.last_watermark_group_id
188 .set(backup.last_watermark_group_id);
189 }
190
191 pub fn add_rule_applied(&self, num: usize) {
192 *self.total_rule_applied.borrow_mut() += num;
193 }
194
195 pub fn total_rule_applied(&self) -> usize {
196 *self.total_rule_applied.borrow()
197 }
198
199 pub fn is_explain_verbose(&self) -> bool {
200 self.explain_options.verbose
201 }
202
203 pub fn is_explain_trace(&self) -> bool {
204 self.explain_options.trace
205 }
206
207 fn is_explain_logical(&self) -> bool {
208 self.explain_options.explain_type == ExplainType::Logical
209 }
210
211 pub fn trace(&self, str: impl Into<String>) {
212 if self.is_explain_logical() && self.logical_explain.borrow().is_some() {
214 return;
215 }
216 let mut optimizer_trace = self.optimizer_trace.borrow_mut();
217 let string = str.into();
218 tracing::info!(target: "explain_trace", "\n{}", string);
219 optimizer_trace.push(string);
220 optimizer_trace.push("\n".to_owned());
221 }
222
223 pub fn warn_to_user(&self, str: impl Into<String>) {
224 self.session_ctx().notice_to_user(str);
225 }
226
227 fn explain_plan_impl(&self, plan: &impl Explain) -> String {
228 match self.explain_options.explain_format {
229 ExplainFormat::Text => plan.explain_to_string(),
230 ExplainFormat::Json => plan.explain_to_json(),
231 ExplainFormat::Xml => plan.explain_to_xml(),
232 ExplainFormat::Yaml => plan.explain_to_yaml(),
233 ExplainFormat::Dot => plan.explain_to_dot(),
234 }
235 }
236
237 pub fn may_store_explain_logical(&self, plan: &LogicalPlanRef) {
238 if self.is_explain_logical() {
239 let str = self.explain_plan_impl(plan);
240 *self.logical_explain.borrow_mut() = Some(str);
241 }
242 }
243
244 pub fn take_logical(&self) -> Option<String> {
245 self.logical_explain.borrow_mut().take()
246 }
247
248 pub fn take_trace(&self) -> Vec<String> {
249 self.optimizer_trace.borrow_mut().drain(..).collect()
250 }
251
252 pub fn with_options(&self) -> &WithOptions {
253 &self.with_options
254 }
255
256 pub fn overwrite_options(&self) -> &OverwriteOptions {
257 &self.overwrite_options
258 }
259
260 pub fn session_ctx(&self) -> &Arc<SessionImpl> {
261 &self.session_ctx
262 }
263
264 pub fn sql(&self) -> &str {
266 &self.sql
267 }
268
269 pub fn normalized_sql(&self) -> &str {
271 &self.normalized_sql
272 }
273
274 pub fn session_timezone(&self) -> RefMut<'_, SessionTimezone> {
275 self.session_timezone.borrow_mut()
276 }
277
278 pub fn get_session_timezone(&self) -> String {
279 self.session_timezone.borrow().timezone()
280 }
281
282 pub fn get_rcte_cache_plan(&self, id: &ShareId) -> Option<PlanRef> {
283 self.rcte_cache.borrow().get(id).cloned()
284 }
285
286 pub fn insert_rcte_cache_plan(&self, id: ShareId, plan: PlanRef) {
287 self.rcte_cache.borrow_mut().insert(id, plan);
288 }
289}
290
291impl std::fmt::Debug for OptimizerContext {
292 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293 write!(
294 f,
295 "QueryContext {{ sql = {}, explain_options = {}, with_options = {:?}, last_plan_node_id = {}, last_correlated_id = {} }}",
296 self.sql,
297 self.explain_options,
298 self.with_options,
299 self.last_plan_node_id.get(),
300 self.last_correlated_id.get(),
301 )
302 }
303}