1use std::cell::RefCell;
16use std::collections::HashMap;
17use std::rc::Rc;
18
19use risingwave_common::catalog::Schema;
20use risingwave_common::types::DataType;
21use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
22use risingwave_sqlparser::ast::{
23 Corresponding, Cte, CteInner, Expr, Fetch, OrderByExpr, Query, SetExpr, SetOperator, Value,
24 With,
25};
26use thiserror_ext::AsReport;
27
28use super::BoundValues;
29use super::bind_context::BindingCteState;
30use super::statement::RewriteExprsRecursive;
31use crate::binder::bind_context::{BindingCte, RecursiveUnion};
32use crate::binder::{Binder, BoundSetExpr};
33use crate::error::{ErrorCode, Result, RwError};
34use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter};
35
/// A bound query: a set expression (`body`) plus the query-level `ORDER BY`, `LIMIT`/`FETCH`,
/// and `OFFSET` clauses that apply to its result.
#[derive(Debug, Clone)]
pub struct BoundQuery {
    /// The bound body of the query (e.g. a `SELECT`, `VALUES`, or set operation).
    pub body: BoundSetExpr,
    /// `ORDER BY` keys. Each index refers either to a visible output column of `body`
    /// (`0..body.schema().len()`), or, when at or past that count, to an entry of
    /// `extra_order_exprs` (offset by the number of visible output columns).
    pub order: Vec<ColumnOrder>,
    /// Row limit from `LIMIT` or `FETCH FIRST`; `LIMIT NULL` is bound as `u64::MAX`.
    pub limit: Option<u64>,
    /// Number of rows to skip from `OFFSET`.
    pub offset: Option<u64>,
    /// Whether `FETCH ... WITH TIES` was specified (always `false` for plain `LIMIT`).
    pub with_ties: bool,
    /// `ORDER BY` expressions that are not output columns of `body`; they are referenced by
    /// `order` at indices after the visible output columns.
    pub extra_order_exprs: Vec<ExprImpl>,
}
47
48impl BoundQuery {
49 pub fn schema(&self) -> std::borrow::Cow<'_, Schema> {
51 self.body.schema()
52 }
53
54 pub fn data_types(&self) -> Vec<DataType> {
56 self.schema().data_types()
57 }
58
59 pub fn is_correlated_by_depth(&self, depth: Depth) -> bool {
90 self.body.is_correlated_by_depth(depth + 1)
91 || self
92 .extra_order_exprs
93 .iter()
94 .any(|e| e.has_correlated_input_ref_by_depth(depth + 1))
95 }
96
97 pub fn is_correlated_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
98 self.body.is_correlated_by_correlated_id(correlated_id)
99 || self
100 .extra_order_exprs
101 .iter()
102 .any(|e| e.has_correlated_input_ref_by_correlated_id(correlated_id))
103 }
104
105 pub fn collect_correlated_indices_by_depth_and_assign_id(
106 &mut self,
107 depth: Depth,
108 correlated_id: CorrelatedId,
109 ) -> Vec<usize> {
110 let mut correlated_indices = vec![];
111
112 correlated_indices.extend(
113 self.body
114 .collect_correlated_indices_by_depth_and_assign_id(depth + 1, correlated_id),
115 );
116
117 correlated_indices.extend(self.extra_order_exprs.iter_mut().flat_map(|expr| {
118 expr.collect_correlated_indices_by_depth_and_assign_id(depth + 1, correlated_id)
119 }));
120 correlated_indices
121 }
122
123 pub fn with_values(values: BoundValues) -> Self {
125 BoundQuery {
126 body: BoundSetExpr::Values(values.into()),
127 order: vec![],
128 limit: None,
129 offset: None,
130 with_ties: false,
131 extra_order_exprs: vec![],
132 }
133 }
134}
135
136impl RewriteExprsRecursive for BoundQuery {
137 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl ExprRewriter) {
138 let new_extra_order_exprs = std::mem::take(&mut self.extra_order_exprs)
139 .into_iter()
140 .map(|expr| rewriter.rewrite_expr(expr))
141 .collect::<Vec<_>>();
142 self.extra_order_exprs = new_extra_order_exprs;
143
144 self.body.rewrite_exprs_recursive(rewriter);
145 }
146}
147
148impl Binder {
149 pub fn bind_query(&mut self, query: &Query) -> Result<BoundQuery> {
156 self.push_context();
157 let result = self.bind_query_inner(query);
158 self.pop_context()?;
159 result
160 }
161
162 pub fn bind_query_for_view(&mut self, query: &Query) -> Result<BoundQuery> {
165 self.push_context();
166 self.context.disable_security_invoker = true;
167 let result = self.bind_query_inner(query);
168 self.pop_context()?;
169 result
170 }
171
172 pub(super) fn bind_query_inner(
174 &mut self,
175 Query {
176 with,
177 body,
178 order_by,
179 limit,
180 offset,
181 fetch,
182 }: &Query,
183 ) -> Result<BoundQuery> {
184 let mut with_ties = false;
185 let limit = match (limit, fetch) {
186 (None, None) => None,
187 (
188 None,
189 Some(Fetch {
190 with_ties: fetch_with_ties,
191 quantity,
192 }),
193 ) => {
194 with_ties = *fetch_with_ties;
195 match quantity {
196 Some(v) => Some(Expr::Value(Value::Number(v.clone()))),
197 None => Some(Expr::Value(Value::Number("1".to_owned()))),
198 }
199 }
200 (Some(limit), None) => Some(limit.clone()),
201 (Some(_), Some(_)) => unreachable!(), };
203 let limit_expr = limit.map(|expr| self.bind_expr(&expr)).transpose()?;
204 let limit = if let Some(limit_expr) = limit_expr {
205 let limit_cast_to_bigint = limit_expr.cast_assign(&DataType::Int64).map_err(|_| {
207 RwError::from(ErrorCode::ExprError(
208 "expects an integer or expression that can be evaluated to an integer after LIMIT"
209 .into(),
210 ))
211 })?;
212 let limit = match limit_cast_to_bigint.try_fold_const() {
213 Some(Ok(Some(datum))) => {
214 let value = datum.as_int64();
215 if *value < 0 {
216 return Err(ErrorCode::ExprError(
217 format!("LIMIT must not be negative, but found: {}", *value).into(),
218 )
219 .into());
220 }
221 *value as u64
222 }
223 Some(Ok(None)) => {
225 u64::MAX
226 }
227 None => return Err(ErrorCode::ExprError(
229 "expects an integer or expression that can be evaluated to an integer after LIMIT, but found non-const expression"
230 .into(),
231 ).into()),
232 Some(Err(e)) => {
234 return Err(ErrorCode::ExprError(
235 format!("expects an integer or expression that can be evaluated to an integer after LIMIT,\nbut the evaluation of the expression returns error:{}", e.as_report()
236 ).into(),
237 ).into())
238 }
239 };
240 Some(limit)
241 } else {
242 None
243 };
244
245 let offset = offset
246 .as_ref()
247 .map(|s| parse_non_negative_i64("OFFSET", s))
248 .transpose()?
249 .map(|v| v as u64);
250
251 if let Some(with) = with {
252 self.bind_with(with)?;
253 }
254 let body = self.bind_set_expr(body)?;
255 let name_to_index =
256 Self::build_name_to_index(body.schema().fields().iter().map(|f| f.name.clone()));
257 let mut extra_order_exprs = vec![];
258 let visible_output_num = body.schema().len();
259 let order = order_by
260 .iter()
261 .map(|order_by_expr| {
262 self.bind_order_by_expr_in_query(
263 order_by_expr,
264 &name_to_index,
265 &mut extra_order_exprs,
266 visible_output_num,
267 )
268 })
269 .collect::<Result<_>>()?;
270 Ok(BoundQuery {
271 body,
272 order,
273 limit,
274 offset,
275 with_ties,
276 extra_order_exprs,
277 })
278 }
279
280 pub fn build_name_to_index(names: impl Iterator<Item = String>) -> HashMap<String, usize> {
281 let mut m = HashMap::new();
282 names.enumerate().for_each(|(index, name)| {
283 m.entry(name)
284 .and_modify(|v| *v = usize::MAX)
287 .or_insert(index);
288 });
289 m
290 }
291
292 fn bind_order_by_expr_in_query(
305 &mut self,
306 OrderByExpr {
307 expr,
308 asc,
309 nulls_first,
310 }: &OrderByExpr,
311 name_to_index: &HashMap<String, usize>,
312 extra_order_exprs: &mut Vec<ExprImpl>,
313 visible_output_num: usize,
314 ) -> Result<ColumnOrder> {
315 let order_type = OrderType::from_bools(*asc, *nulls_first);
316 let column_index = match expr {
317 Expr::Identifier(name) if let Some(index) = name_to_index.get(&name.real_value()) => {
318 match *index != usize::MAX {
319 true => *index,
320 false => {
321 return Err(ErrorCode::BindError(format!(
322 "ORDER BY \"{}\" is ambiguous",
323 name.real_value()
324 ))
325 .into());
326 }
327 }
328 }
329 Expr::Value(Value::Number(number)) => match number.parse::<usize>() {
330 Ok(index) if 1 <= index && index <= visible_output_num => index - 1,
331 _ => {
332 return Err(ErrorCode::InvalidInputSyntax(format!(
333 "Invalid ordinal number in ORDER BY: {}",
334 number
335 ))
336 .into());
337 }
338 },
339 expr => {
340 extra_order_exprs.push(self.bind_expr(expr)?);
341 visible_output_num + extra_order_exprs.len() - 1
342 }
343 };
344 Ok(ColumnOrder::new(column_index, order_type))
345 }
346
347 fn bind_with(&mut self, with: &With) -> Result<()> {
348 for cte_table in &with.cte_tables {
349 let share_id = self.next_share_id();
351 let Cte { alias, cte_inner } = cte_table;
352 let table_name = alias.name.real_value();
353
354 if with.recursive {
355 if let CteInner::Query(query) = cte_inner {
356 let (all, corresponding, left, right, with) = Self::validate_rcte(query)?;
357
358 assert!(
360 !corresponding.is_corresponding(),
361 "`CORRESPONDING` is not supported in recursive CTE"
362 );
363
364 let entry = self
365 .context
366 .cte_to_relation
367 .entry(table_name)
368 .insert_entry(Rc::new(RefCell::new(BindingCte {
369 share_id,
370 state: BindingCteState::Init,
371 alias: alias.clone(),
372 })))
373 .get()
374 .clone();
375
376 self.bind_rcte(with, entry, left, right, all)?;
377 } else {
378 return Err(ErrorCode::BindError(
379 "RECURSIVE CTE only support query".to_owned(),
380 )
381 .into());
382 }
383 } else {
384 match cte_inner {
385 CteInner::Query(query) => {
386 let bound_query = self.bind_query(query)?;
387 self.context.cte_to_relation.insert(
388 table_name,
389 Rc::new(RefCell::new(BindingCte {
390 share_id,
391 state: BindingCteState::Bound {
392 query: either::Either::Left(bound_query),
393 },
394 alias: alias.clone(),
395 })),
396 );
397 }
398 CteInner::ChangeLog(from_table_name) => {
399 self.push_context();
400 let from_table_relation =
401 self.bind_relation_by_name(from_table_name, None, None, true)?;
402 self.pop_context()?;
403 self.context.cte_to_relation.insert(
404 table_name,
405 Rc::new(RefCell::new(BindingCte {
406 share_id,
407 state: BindingCteState::ChangeLog {
408 table: from_table_relation,
409 },
410 alias: alias.clone(),
411 })),
412 );
413 }
414 }
415 }
416 }
417 Ok(())
418 }
419
420 fn validate_rcte(
422 query: &Query,
423 ) -> Result<(bool, &Corresponding, &SetExpr, &SetExpr, Option<&With>)> {
424 let Query {
425 with,
426 body,
427 order_by,
428 limit,
429 offset,
430 fetch,
431 } = query;
432
433 fn should_be_empty<T>(v: Option<T>, clause: &str) -> Result<()> {
435 if v.is_some() {
436 return Err(ErrorCode::BindError(format!(
437 "`{clause}` is not supported in recursive CTE"
438 ))
439 .into());
440 }
441 Ok(())
442 }
443
444 should_be_empty(order_by.first(), "ORDER BY")?;
445 should_be_empty(limit.as_ref(), "LIMIT")?;
446 should_be_empty(offset.as_ref(), "OFFSET")?;
447 should_be_empty(fetch.as_ref(), "FETCH")?;
448
449 let SetExpr::SetOperation {
450 op: SetOperator::Union,
451 all,
452 corresponding,
453 left,
454 right,
455 } = body
456 else {
457 return Err(
458 ErrorCode::BindError("`UNION` is required in recursive CTE".to_owned()).into(),
459 );
460 };
461
462 if !all {
463 return Err(ErrorCode::BindError(
464 "only `UNION ALL` is supported in recursive CTE now".to_owned(),
465 )
466 .into());
467 }
468
469 if corresponding.is_corresponding() {
470 return Err(ErrorCode::BindError(
471 "`CORRESPONDING` is not supported in recursive CTE".to_owned(),
472 )
473 .into());
474 }
475
476 Ok((*all, corresponding, left, right, with.as_ref()))
477 }
478
479 fn bind_rcte(
480 &mut self,
481 with: Option<&With>,
482 entry: Rc<RefCell<BindingCte>>,
483 left: &SetExpr,
484 right: &SetExpr,
485 all: bool,
486 ) -> Result<()> {
487 self.push_context();
488 let result = self.bind_rcte_inner(with, entry, left, right, all);
489 self.pop_context()?;
490 result
491 }
492
493 fn bind_rcte_inner(
494 &mut self,
495 with: Option<&With>,
496 entry: Rc<RefCell<BindingCte>>,
497 left: &SetExpr,
498 right: &SetExpr,
499 all: bool,
500 ) -> Result<()> {
501 if let Some(with) = with {
502 self.bind_with(with)?;
503 }
504
505 let mut base = self.bind_set_expr(left)?;
509
510 entry.borrow_mut().state = BindingCteState::BaseResolved { base: base.clone() };
511
512 let new_context = std::mem::take(&mut self.context);
514 self.context
515 .cte_to_relation
516 .clone_from(&new_context.cte_to_relation);
517 self.context.disable_security_invoker = new_context.disable_security_invoker;
518 let mut recursive = self.bind_set_expr(right)?;
520 self.context = Default::default();
522 self.context.cte_to_relation = new_context.cte_to_relation;
523 self.context.disable_security_invoker = new_context.disable_security_invoker;
524
525 Self::align_schema(&mut base, &mut recursive, SetOperator::Union)?;
526 let schema = base.schema().into_owned();
527
528 let recursive_union = RecursiveUnion {
529 all,
530 base: Box::new(base),
531 recursive: Box::new(recursive),
532 schema,
533 };
534
535 entry.borrow_mut().state = BindingCteState::Bound {
536 query: either::Either::Right(recursive_union),
537 };
538
539 Ok(())
540 }
541}
542
543fn parse_non_negative_i64(clause: &str, s: &str) -> Result<i64> {
545 match s.parse::<i64>() {
546 Ok(v) => {
547 if v < 0 {
548 Err(ErrorCode::InvalidInputSyntax(format!("{clause} must not be negative")).into())
549 } else {
550 Ok(v)
551 }
552 }
553 Err(e) => Err(ErrorCode::InvalidInputSyntax(e.to_report_string()).into()),
554 }
555}