1use std::cell::RefCell;
16use std::collections::HashMap;
17use std::rc::Rc;
18
19use risingwave_common::catalog::Schema;
20use risingwave_common::types::DataType;
21use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
22use risingwave_sqlparser::ast::{
23 Cte, CteInner, Expr, Fetch, OrderByExpr, Query, SetExpr, SetOperator, Value, With,
24};
25use thiserror_ext::AsReport;
26
27use super::BoundValues;
28use super::bind_context::BindingCteState;
29use super::statement::RewriteExprsRecursive;
30use crate::binder::bind_context::{BindingCte, RecursiveUnion};
31use crate::binder::{Binder, BoundSetExpr};
32use crate::error::{ErrorCode, Result, RwError};
33use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter};
34
/// A bound query: the bound body together with its ORDER BY, LIMIT, OFFSET and FETCH
/// information.
#[derive(Debug, Clone)]
pub struct BoundQuery {
    /// The bound query body.
    pub body: BoundSetExpr,
    /// ORDER BY keys. A column index smaller than the number of body output columns refers to
    /// that output column; larger indices refer to `extra_order_exprs`, offset by the number of
    /// body output columns (see `Binder::bind_order_by_expr_in_query`).
    pub order: Vec<ColumnOrder>,
    /// Row limit from `LIMIT` or `FETCH FIRST`. A `LIMIT` that folds to NULL is stored as
    /// `u64::MAX`.
    pub limit: Option<u64>,
    /// Number of rows to skip, from `OFFSET`.
    pub offset: Option<u64>,
    /// Whether `FETCH ... WITH TIES` was specified; always `false` for a plain `LIMIT`.
    pub with_ties: bool,
    /// ORDER BY expressions that are neither an output column name nor an ordinal; they are
    /// ordered on in addition to the visible output columns.
    pub extra_order_exprs: Vec<ExprImpl>,
}
46
47impl BoundQuery {
48 pub fn schema(&self) -> std::borrow::Cow<'_, Schema> {
50 self.body.schema()
51 }
52
53 pub fn data_types(&self) -> Vec<DataType> {
55 self.schema().data_types()
56 }
57
58 pub fn is_correlated_by_depth(&self, depth: Depth) -> bool {
89 self.body.is_correlated_by_depth(depth + 1)
90 || self
91 .extra_order_exprs
92 .iter()
93 .any(|e| e.has_correlated_input_ref_by_depth(depth + 1))
94 }
95
96 pub fn is_correlated_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
97 self.body.is_correlated_by_correlated_id(correlated_id)
98 || self
99 .extra_order_exprs
100 .iter()
101 .any(|e| e.has_correlated_input_ref_by_correlated_id(correlated_id))
102 }
103
104 pub fn collect_correlated_indices_by_depth_and_assign_id(
105 &mut self,
106 depth: Depth,
107 correlated_id: CorrelatedId,
108 ) -> Vec<usize> {
109 let mut correlated_indices = vec![];
110
111 correlated_indices.extend(
112 self.body
113 .collect_correlated_indices_by_depth_and_assign_id(depth + 1, correlated_id),
114 );
115
116 correlated_indices.extend(self.extra_order_exprs.iter_mut().flat_map(|expr| {
117 expr.collect_correlated_indices_by_depth_and_assign_id(depth + 1, correlated_id)
118 }));
119 correlated_indices
120 }
121
122 pub fn with_values(values: BoundValues) -> Self {
124 BoundQuery {
125 body: BoundSetExpr::Values(values.into()),
126 order: vec![],
127 limit: None,
128 offset: None,
129 with_ties: false,
130 extra_order_exprs: vec![],
131 }
132 }
133}
134
135impl RewriteExprsRecursive for BoundQuery {
136 fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl ExprRewriter) {
137 let new_extra_order_exprs = std::mem::take(&mut self.extra_order_exprs)
138 .into_iter()
139 .map(|expr| rewriter.rewrite_expr(expr))
140 .collect::<Vec<_>>();
141 self.extra_order_exprs = new_extra_order_exprs;
142
143 self.body.rewrite_exprs_recursive(rewriter);
144 }
145}
146
147impl Binder {
148 pub fn bind_query(&mut self, query: Query) -> Result<BoundQuery> {
155 self.push_context();
156 let result = self.bind_query_inner(query);
157 self.pop_context()?;
158 result
159 }
160
161 pub fn bind_query_for_view(&mut self, query: Query) -> Result<BoundQuery> {
164 self.push_context();
165 self.context.disable_security_invoker = true;
166 let result = self.bind_query_inner(query);
167 self.pop_context()?;
168 result
169 }
170
171 pub(super) fn bind_query_inner(
173 &mut self,
174 Query {
175 with,
176 body,
177 order_by,
178 limit,
179 offset,
180 fetch,
181 }: Query,
182 ) -> Result<BoundQuery> {
183 let mut with_ties = false;
184 let limit = match (limit, fetch) {
185 (None, None) => None,
186 (
187 None,
188 Some(Fetch {
189 with_ties: fetch_with_ties,
190 quantity,
191 }),
192 ) => {
193 with_ties = fetch_with_ties;
194 match quantity {
195 Some(v) => Some(Expr::Value(Value::Number(v))),
196 None => Some(Expr::Value(Value::Number("1".to_owned()))),
197 }
198 }
199 (Some(limit), None) => Some(limit),
200 (Some(_), Some(_)) => unreachable!(), };
202 let limit_expr = limit.map(|expr| self.bind_expr(expr)).transpose()?;
203 let limit = if let Some(limit_expr) = limit_expr {
204 let limit_cast_to_bigint = limit_expr.cast_assign(DataType::Int64).map_err(|_| {
206 RwError::from(ErrorCode::ExprError(
207 "expects an integer or expression that can be evaluated to an integer after LIMIT"
208 .into(),
209 ))
210 })?;
211 let limit = match limit_cast_to_bigint.try_fold_const() {
212 Some(Ok(Some(datum))) => {
213 let value = datum.as_int64();
214 if *value < 0 {
215 return Err(ErrorCode::ExprError(
216 format!("LIMIT must not be negative, but found: {}", *value).into(),
217 )
218 .into());
219 }
220 *value as u64
221 }
222 Some(Ok(None)) => {
224 u64::MAX
225 }
226 None => return Err(ErrorCode::ExprError(
228 "expects an integer or expression that can be evaluated to an integer after LIMIT, but found non-const expression"
229 .into(),
230 ).into()),
231 Some(Err(e)) => {
233 return Err(ErrorCode::ExprError(
234 format!("expects an integer or expression that can be evaluated to an integer after LIMIT,\nbut the evaluation of the expression returns error:{}", e.as_report()
235 ).into(),
236 ).into())
237 }
238 };
239 Some(limit)
240 } else {
241 None
242 };
243
244 let offset = offset
245 .map(|s| parse_non_negative_i64("OFFSET", &s))
246 .transpose()?
247 .map(|v| v as u64);
248
249 if let Some(with) = with {
250 self.bind_with(with)?;
251 }
252 let body = self.bind_set_expr(body)?;
253 let name_to_index =
254 Self::build_name_to_index(body.schema().fields().iter().map(|f| f.name.clone()));
255 let mut extra_order_exprs = vec![];
256 let visible_output_num = body.schema().len();
257 let order = order_by
258 .into_iter()
259 .map(|order_by_expr| {
260 self.bind_order_by_expr_in_query(
261 order_by_expr,
262 &name_to_index,
263 &mut extra_order_exprs,
264 visible_output_num,
265 )
266 })
267 .collect::<Result<_>>()?;
268 Ok(BoundQuery {
269 body,
270 order,
271 limit,
272 offset,
273 with_ties,
274 extra_order_exprs,
275 })
276 }
277
278 pub fn build_name_to_index(names: impl Iterator<Item = String>) -> HashMap<String, usize> {
279 let mut m = HashMap::new();
280 names.enumerate().for_each(|(index, name)| {
281 m.entry(name)
282 .and_modify(|v| *v = usize::MAX)
285 .or_insert(index);
286 });
287 m
288 }
289
290 fn bind_order_by_expr_in_query(
303 &mut self,
304 OrderByExpr {
305 expr,
306 asc,
307 nulls_first,
308 }: OrderByExpr,
309 name_to_index: &HashMap<String, usize>,
310 extra_order_exprs: &mut Vec<ExprImpl>,
311 visible_output_num: usize,
312 ) -> Result<ColumnOrder> {
313 let order_type = OrderType::from_bools(asc, nulls_first);
314 let column_index = match expr {
315 Expr::Identifier(name) if let Some(index) = name_to_index.get(&name.real_value()) => {
316 match *index != usize::MAX {
317 true => *index,
318 false => {
319 return Err(ErrorCode::BindError(format!(
320 "ORDER BY \"{}\" is ambiguous",
321 name.real_value()
322 ))
323 .into());
324 }
325 }
326 }
327 Expr::Value(Value::Number(number)) => match number.parse::<usize>() {
328 Ok(index) if 1 <= index && index <= visible_output_num => index - 1,
329 _ => {
330 return Err(ErrorCode::InvalidInputSyntax(format!(
331 "Invalid ordinal number in ORDER BY: {}",
332 number
333 ))
334 .into());
335 }
336 },
337 expr => {
338 extra_order_exprs.push(self.bind_expr(expr)?);
339 visible_output_num + extra_order_exprs.len() - 1
340 }
341 };
342 Ok(ColumnOrder::new(column_index, order_type))
343 }
344
345 fn bind_with(&mut self, with: With) -> Result<()> {
346 for cte_table in with.cte_tables {
347 let share_id = self.next_share_id();
349 let Cte { alias, cte_inner } = cte_table;
350 let table_name = alias.name.real_value();
351
352 if with.recursive {
353 if let CteInner::Query(query) = cte_inner {
354 let (
355 SetExpr::SetOperation {
356 op: SetOperator::Union,
357 all,
358 corresponding,
359 left,
360 right,
361 },
362 with,
363 ) = Self::validate_rcte(*query)?
364 else {
365 return Err(ErrorCode::BindError(
366 "expect `SetOperation` as the return type of validation".into(),
367 )
368 .into());
369 };
370
371 assert!(
373 !corresponding.is_corresponding(),
374 "`CORRESPONDING` is not supported in recursive CTE"
375 );
376
377 let entry = self
378 .context
379 .cte_to_relation
380 .entry(table_name)
381 .insert_entry(Rc::new(RefCell::new(BindingCte {
382 share_id,
383 state: BindingCteState::Init,
384 alias,
385 })))
386 .get()
387 .clone();
388
389 self.bind_rcte(with, entry, *left, *right, all)?;
390 } else {
391 return Err(ErrorCode::BindError(
392 "RECURSIVE CTE only support query".to_owned(),
393 )
394 .into());
395 }
396 } else {
397 match cte_inner {
398 CteInner::Query(query) => {
399 let bound_query = self.bind_query(*query)?;
400 self.context.cte_to_relation.insert(
401 table_name,
402 Rc::new(RefCell::new(BindingCte {
403 share_id,
404 state: BindingCteState::Bound {
405 query: either::Either::Left(bound_query),
406 },
407 alias,
408 })),
409 );
410 }
411 CteInner::ChangeLog(from_table_name) => {
412 self.push_context();
413 let from_table_relation =
414 self.bind_relation_by_name(from_table_name.clone(), None, None, true)?;
415 self.pop_context()?;
416 self.context.cte_to_relation.insert(
417 table_name,
418 Rc::new(RefCell::new(BindingCte {
419 share_id,
420 state: BindingCteState::ChangeLog {
421 table: from_table_relation,
422 },
423 alias,
424 })),
425 );
426 }
427 }
428 }
429 }
430 Ok(())
431 }
432
433 fn validate_rcte(query: Query) -> Result<(SetExpr, Option<With>)> {
435 let Query {
436 with,
437 body,
438 order_by,
439 limit,
440 offset,
441 fetch,
442 } = query;
443
444 fn should_be_empty<T>(v: Option<T>, clause: &str) -> Result<()> {
446 if v.is_some() {
447 return Err(ErrorCode::BindError(format!(
448 "`{clause}` is not supported in recursive CTE"
449 ))
450 .into());
451 }
452 Ok(())
453 }
454
455 should_be_empty(order_by.first(), "ORDER BY")?;
456 should_be_empty(limit, "LIMIT")?;
457 should_be_empty(offset, "OFFSET")?;
458 should_be_empty(fetch, "FETCH")?;
459
460 let SetExpr::SetOperation {
461 op: SetOperator::Union,
462 all,
463 corresponding,
464 left,
465 right,
466 } = body
467 else {
468 return Err(
469 ErrorCode::BindError("`UNION` is required in recursive CTE".to_owned()).into(),
470 );
471 };
472
473 if !all {
474 return Err(ErrorCode::BindError(
475 "only `UNION ALL` is supported in recursive CTE now".to_owned(),
476 )
477 .into());
478 }
479
480 if corresponding.is_corresponding() {
481 return Err(ErrorCode::BindError(
482 "`CORRESPONDING` is not supported in recursive CTE".to_owned(),
483 )
484 .into());
485 }
486
487 Ok((
488 SetExpr::SetOperation {
489 op: SetOperator::Union,
490 all,
491 corresponding,
492 left,
493 right,
494 },
495 with,
496 ))
497 }
498
499 fn bind_rcte(
500 &mut self,
501 with: Option<With>,
502 entry: Rc<RefCell<BindingCte>>,
503 left: SetExpr,
504 right: SetExpr,
505 all: bool,
506 ) -> Result<()> {
507 self.push_context();
508 let result = self.bind_rcte_inner(with, entry, left, right, all);
509 self.pop_context()?;
510 result
511 }
512
513 fn bind_rcte_inner(
514 &mut self,
515 with: Option<With>,
516 entry: Rc<RefCell<BindingCte>>,
517 left: SetExpr,
518 right: SetExpr,
519 all: bool,
520 ) -> Result<()> {
521 if let Some(with) = with {
522 self.bind_with(with)?;
523 }
524
525 let mut base = self.bind_set_expr(left)?;
529
530 entry.borrow_mut().state = BindingCteState::BaseResolved { base: base.clone() };
531
532 let new_context = std::mem::take(&mut self.context);
534 self.context
535 .cte_to_relation
536 .clone_from(&new_context.cte_to_relation);
537 self.context.disable_security_invoker = new_context.disable_security_invoker;
538 let mut recursive = self.bind_set_expr(right)?;
540 self.context = Default::default();
542 self.context.cte_to_relation = new_context.cte_to_relation;
543 self.context.disable_security_invoker = new_context.disable_security_invoker;
544
545 Self::align_schema(&mut base, &mut recursive, SetOperator::Union)?;
546 let schema = base.schema().into_owned();
547
548 let recursive_union = RecursiveUnion {
549 all,
550 base: Box::new(base),
551 recursive: Box::new(recursive),
552 schema,
553 };
554
555 entry.borrow_mut().state = BindingCteState::Bound {
556 query: either::Either::Right(recursive_union),
557 };
558
559 Ok(())
560 }
561}
562
563fn parse_non_negative_i64(clause: &str, s: &str) -> Result<i64> {
565 match s.parse::<i64>() {
566 Ok(v) => {
567 if v < 0 {
568 Err(ErrorCode::InvalidInputSyntax(format!("{clause} must not be negative")).into())
569 } else {
570 Ok(v)
571 }
572 }
573 Err(e) => Err(ErrorCode::InvalidInputSyntax(e.to_report_string()).into()),
574 }
575}