risingwave_frontend/optimizer/
optimizer_context.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::fmt::Formatter;
16use std::cell::{Cell, RefCell, RefMut};
17use std::collections::HashMap;
18use std::marker::PhantomData;
19use std::rc::Rc;
20use std::sync::Arc;
21
22use risingwave_sqlparser::ast::{ExplainFormat, ExplainOptions, ExplainType};
23
24use super::property::WatermarkGroupId;
25use crate::binder::ShareId;
26use crate::expr::{CorrelatedId, SessionTimezone};
27use crate::handler::HandlerArgs;
28use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, PlanNodeId};
29use crate::session::SessionImpl;
30use crate::utils::{OverwriteOptions, WithOptions};
31
/// Starting value of the plan-node, expr-display and watermark-group ID counters set up by
/// [`OptimizerContext::new`]; the first ID handed out from such a context is
/// `RESERVED_ID_NUM + 1`.
// NOTE(review): presumably IDs up to this value are reserved for use elsewhere — confirm.
const RESERVED_ID_NUM: u16 = 10000;

/// Zero-sized marker that makes any type containing it `!Send` and `!Sync`, because `Rc<()>`
/// is neither.
type PhantomUnsend = PhantomData<Rc<()>>;
35
/// State the optimizer carries for a single [`HandlerArgs`]: the session handle, the SQL
/// text, explain settings and collected explain output, `WITH`-clause options, and the
/// counters used to hand out fresh plan-node, correlated, expr-display and watermark-group
/// IDs.
///
/// Mutable state is kept in `Cell`/`RefCell` fields so that it can be updated through
/// `&self`. The `_phantom` marker keeps the type `!Send`.
pub struct OptimizerContext {
    /// Session handle taken from the [`HandlerArgs`].
    session_ctx: Arc<SessionImpl>,
    /// The original SQL string, used for debugging.
    sql: Arc<str>,
    /// Normalized SQL string. See [`HandlerArgs::normalize_sql`].
    normalized_sql: String,
    /// Explain options.
    explain_options: ExplainOptions,
    /// Trace of the optimizer; each recorded message is followed by a separate `"\n"` entry.
    optimizer_trace: RefCell<Vec<String>>,
    /// The optimized logical plan of the optimizer, if one has been stored.
    logical_explain: RefCell<Option<String>>,
    /// Options or properties from the `with` clause.
    with_options: WithOptions,
    /// The session timezone and whether it was used.
    session_timezone: RefCell<SessionTimezone>,
    /// Total number of optimization rules that have been applied.
    total_rule_applied: RefCell<usize>,
    /// Configs that can be overwritten in the `with` clause;
    /// if not specified there, the value from the session variable is used.
    overwrite_options: OverwriteOptions,
    /// Mapping between `share_id` and the corresponding
    /// `PlanRef`, used by rcte's planning. (e.g., in `LogicalCteRef`)
    rcte_cache: RefCell<HashMap<ShareId, PlanRef>>,

    /// Last assigned plan node ID.
    last_plan_node_id: Cell<i32>,
    /// Last assigned correlated ID.
    last_correlated_id: Cell<u32>,
    /// Last assigned expr display ID.
    last_expr_display_id: Cell<usize>,
    /// Last assigned watermark group ID.
    last_watermark_group_id: Cell<u32>,

    /// Marker that opts the context out of `Send`/`Sync` (see `PhantomUnsend`).
    _phantom: PhantomUnsend,
}
72
/// Snapshot of the four "last assigned" ID counters of an [`OptimizerContext`].
///
/// Produced by `OptimizerContext::backup_elem_ids` and consumed by
/// `OptimizerContext::restore_elem_ids`.
pub(in crate::optimizer) struct LastAssignedIds {
    /// Saved value of the plan node ID counter.
    last_plan_node_id: i32,
    /// Saved value of the correlated ID counter.
    last_correlated_id: u32,
    /// Saved value of the expr display ID counter.
    last_expr_display_id: usize,
    /// Saved value of the watermark group ID counter.
    last_watermark_group_id: u32,
}
79
/// Reference-counted (`Rc`, single-threaded) handle to an [`OptimizerContext`].
pub type OptimizerContextRef = Rc<OptimizerContext>;
81
82impl OptimizerContext {
83    /// Create a new [`OptimizerContext`] from the given [`HandlerArgs`], with empty
84    /// [`ExplainOptions`].
85    pub fn from_handler_args(handler_args: HandlerArgs) -> Self {
86        Self::new(handler_args, ExplainOptions::default())
87    }
88
89    /// Create a new [`OptimizerContext`] from the given [`HandlerArgs`] and [`ExplainOptions`].
90    pub fn new(mut handler_args: HandlerArgs, explain_options: ExplainOptions) -> Self {
91        let session_timezone = RefCell::new(SessionTimezone::new(
92            handler_args.session.config().timezone().to_owned(),
93        ));
94        let overwrite_options = OverwriteOptions::new(&mut handler_args);
95        Self {
96            session_ctx: handler_args.session,
97            sql: handler_args.sql,
98            normalized_sql: handler_args.normalized_sql,
99            explain_options,
100            optimizer_trace: RefCell::new(vec![]),
101            logical_explain: RefCell::new(None),
102            with_options: handler_args.with_options,
103            session_timezone,
104            total_rule_applied: RefCell::new(0),
105            overwrite_options,
106            rcte_cache: RefCell::new(HashMap::new()),
107
108            last_plan_node_id: Cell::new(RESERVED_ID_NUM.into()),
109            last_correlated_id: Cell::new(0),
110            last_expr_display_id: Cell::new(RESERVED_ID_NUM.into()),
111            last_watermark_group_id: Cell::new(RESERVED_ID_NUM.into()),
112
113            _phantom: Default::default(),
114        }
115    }
116
117    // TODO(TaoWu): Remove the async.
118    #[cfg(test)]
119    #[expect(clippy::unused_async)]
120    pub async fn mock() -> OptimizerContextRef {
121        Self {
122            session_ctx: Arc::new(SessionImpl::mock()),
123            sql: Arc::from(""),
124            normalized_sql: "".to_owned(),
125            explain_options: ExplainOptions::default(),
126            optimizer_trace: RefCell::new(vec![]),
127            logical_explain: RefCell::new(None),
128            with_options: Default::default(),
129            session_timezone: RefCell::new(SessionTimezone::new("UTC".into())),
130            total_rule_applied: RefCell::new(0),
131            overwrite_options: OverwriteOptions::default(),
132            rcte_cache: RefCell::new(HashMap::new()),
133
134            last_plan_node_id: Cell::new(0),
135            last_correlated_id: Cell::new(0),
136            last_expr_display_id: Cell::new(0),
137            last_watermark_group_id: Cell::new(0),
138
139            _phantom: Default::default(),
140        }
141        .into()
142    }
143
144    pub fn next_plan_node_id(&self) -> PlanNodeId {
145        self.last_plan_node_id.update(|id| id + 1);
146        PlanNodeId(self.last_plan_node_id.get())
147    }
148
149    pub fn next_correlated_id(&self) -> CorrelatedId {
150        self.last_correlated_id.update(|id| id + 1);
151        self.last_correlated_id.get()
152    }
153
154    pub fn next_expr_display_id(&self) -> usize {
155        self.last_expr_display_id.update(|id| id + 1);
156        self.last_expr_display_id.get()
157    }
158
159    pub fn next_watermark_group_id(&self) -> WatermarkGroupId {
160        self.last_watermark_group_id.update(|id| id + 1);
161        self.last_watermark_group_id.get()
162    }
163
164    pub(in crate::optimizer) fn backup_elem_ids(&self) -> LastAssignedIds {
165        LastAssignedIds {
166            last_plan_node_id: self.last_plan_node_id.get(),
167            last_correlated_id: self.last_correlated_id.get(),
168            last_expr_display_id: self.last_expr_display_id.get(),
169            last_watermark_group_id: self.last_watermark_group_id.get(),
170        }
171    }
172
173    /// This should only be called in [`crate::optimizer::plan_node::reorganize_elements_id`].
174    pub(in crate::optimizer) fn reset_elem_ids(&self) {
175        self.last_plan_node_id.set(0);
176        self.last_correlated_id.set(0);
177        self.last_expr_display_id.set(0);
178        self.last_watermark_group_id.set(0);
179    }
180
181    pub(in crate::optimizer) fn restore_elem_ids(&self, backup: LastAssignedIds) {
182        self.last_plan_node_id.set(backup.last_plan_node_id);
183        self.last_correlated_id.set(backup.last_correlated_id);
184        self.last_expr_display_id.set(backup.last_expr_display_id);
185        self.last_watermark_group_id
186            .set(backup.last_watermark_group_id);
187    }
188
189    pub fn add_rule_applied(&self, num: usize) {
190        *self.total_rule_applied.borrow_mut() += num;
191    }
192
193    pub fn total_rule_applied(&self) -> usize {
194        *self.total_rule_applied.borrow()
195    }
196
197    pub fn is_explain_verbose(&self) -> bool {
198        self.explain_options.verbose
199    }
200
201    pub fn is_explain_trace(&self) -> bool {
202        self.explain_options.trace
203    }
204
205    pub fn is_explain_backfill(&self) -> bool {
206        self.explain_options.backfill
207    }
208
209    pub fn explain_type(&self) -> ExplainType {
210        self.explain_options.explain_type.clone()
211    }
212
213    pub fn explain_format(&self) -> ExplainFormat {
214        self.explain_options.explain_format.clone()
215    }
216
217    pub fn is_explain_logical(&self) -> bool {
218        self.explain_type() == ExplainType::Logical
219    }
220
221    pub fn trace(&self, str: impl Into<String>) {
222        // If explain type is logical, do not store the trace for any optimizations beyond logical.
223        if self.is_explain_logical() && self.logical_explain.borrow().is_some() {
224            return;
225        }
226        let mut optimizer_trace = self.optimizer_trace.borrow_mut();
227        let string = str.into();
228        tracing::info!(target: "explain_trace", "\n{}", string);
229        optimizer_trace.push(string);
230        optimizer_trace.push("\n".to_owned());
231    }
232
233    pub fn warn_to_user(&self, str: impl Into<String>) {
234        self.session_ctx().notice_to_user(str);
235    }
236
237    pub fn store_logical(&self, str: impl Into<String>) {
238        *self.logical_explain.borrow_mut() = Some(str.into())
239    }
240
241    pub fn take_logical(&self) -> Option<String> {
242        self.logical_explain.borrow_mut().take()
243    }
244
245    pub fn take_trace(&self) -> Vec<String> {
246        self.optimizer_trace.borrow_mut().drain(..).collect()
247    }
248
249    pub fn with_options(&self) -> &WithOptions {
250        &self.with_options
251    }
252
253    pub fn overwrite_options(&self) -> &OverwriteOptions {
254        &self.overwrite_options
255    }
256
257    pub fn session_ctx(&self) -> &Arc<SessionImpl> {
258        &self.session_ctx
259    }
260
261    /// Return the original SQL.
262    pub fn sql(&self) -> &str {
263        &self.sql
264    }
265
266    /// Return the normalized SQL.
267    pub fn normalized_sql(&self) -> &str {
268        &self.normalized_sql
269    }
270
271    pub fn session_timezone(&self) -> RefMut<'_, SessionTimezone> {
272        self.session_timezone.borrow_mut()
273    }
274
275    pub fn get_session_timezone(&self) -> String {
276        self.session_timezone.borrow().timezone()
277    }
278
279    pub fn get_rcte_cache_plan(&self, id: &ShareId) -> Option<PlanRef> {
280        self.rcte_cache.borrow().get(id).cloned()
281    }
282
283    pub fn insert_rcte_cache_plan(&self, id: ShareId, plan: PlanRef) {
284        self.rcte_cache.borrow_mut().insert(id, plan);
285    }
286}
287
288impl std::fmt::Debug for OptimizerContext {
289    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
290        write!(
291            f,
292            "QueryContext {{ sql = {}, explain_options = {}, with_options = {:?}, last_plan_node_id = {}, last_correlated_id = {} }}",
293            self.sql,
294            self.explain_options,
295            self.with_options,
296            self.last_plan_node_id.get(),
297            self.last_correlated_id.get(),
298        )
299    }
300}