1use std::fmt;
21
22use risingwave_sqlparser::ast::*;
23use strum::EnumDiscriminants;
24
/// Names a field of an AST node that a path can step into.
///
/// The `Display` form of a variant is the snake_case label used when a path
/// is rendered (for example `AstField::OrderBy` prints as `order_by`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AstField {
    // Statement-level fields.
    Query,
    Name,
    Columns,

    // Fields of a query.
    Body,
    With,
    OrderBy,
    Limit,
    Offset,

    // Fields of a `SELECT`.
    Projection,
    Selection,
    From,
    GroupBy,
    Having,
    Distinct,

    // Fields of expressions.
    Left,
    Right,
    Operand,
    ElseResult,
    Subquery,
    Inner,
    Expr,
    Low,
    High,

    // Fields of a table-with-joins.
    Relation,
    Joins,

    // Field of a join.
    JoinOperator,

    // Alias field.
    Alias,

    // Fields of an `ORDER BY` item.
    Asc,
    NullsFirst,

    // Fields of a `WITH` clause.
    CteTable,
    Recursive,

    // Field of a single CTE.
    CteInner,

    // Field of a derived table factor.
    Lateral,
}

impl fmt::Display for AstField {
    /// Writes the snake_case label of the field.
    ///
    /// Width, fill and alignment flags on `f` are not applied; the label is
    /// always written verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Query => "query",
            Self::Name => "name",
            Self::Columns => "columns",
            Self::Body => "body",
            Self::With => "with",
            Self::OrderBy => "order_by",
            Self::Limit => "limit",
            Self::Offset => "offset",
            Self::Projection => "projection",
            Self::Selection => "selection",
            Self::From => "from",
            Self::GroupBy => "group_by",
            Self::Having => "having",
            Self::Distinct => "distinct",
            Self::Left => "left",
            Self::Right => "right",
            Self::Operand => "operand",
            Self::ElseResult => "else_result",
            Self::Subquery => "subquery",
            Self::Inner => "inner",
            Self::Expr => "expr",
            Self::Low => "low",
            Self::High => "high",
            Self::Relation => "relation",
            Self::Joins => "joins",
            Self::JoinOperator => "join_operator",
            Self::Alias => "alias",
            Self::Asc => "asc",
            Self::NullsFirst => "nulls_first",
            // The label is plural even though the variant name is singular.
            Self::CteTable => "cte_tables",
            Self::Recursive => "recursive",
            Self::CteInner => "cte_inner",
            Self::Lateral => "lateral",
        };
        f.write_str(label)
    }
}
125
/// One step of an AST path: either a named field or a position in a list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PathComponent {
    /// Step into the named field of the current node.
    Field(AstField),
    /// Step into the element at this zero-based position of a list node.
    Index(usize),
}
136
137impl PathComponent {
138 pub fn field(field: AstField) -> Self {
140 PathComponent::Field(field)
141 }
142
143 pub fn field_name(&self) -> Option<String> {
145 match self {
146 PathComponent::Field(field) => Some(field.to_string()),
147 PathComponent::Index(_) => None,
148 }
149 }
150
151 pub fn as_ast_field(&self) -> Option<&AstField> {
153 match self {
154 PathComponent::Field(field) => Some(field),
155 PathComponent::Index(_) => None,
156 }
157 }
158
159 pub fn query() -> Self {
161 Self::field(AstField::Query)
162 }
163
164 pub fn body() -> Self {
165 Self::field(AstField::Body)
166 }
167
168 pub fn selection() -> Self {
169 Self::field(AstField::Selection)
170 }
171
172 pub fn projection() -> Self {
173 Self::field(AstField::Projection)
174 }
175
176 pub fn from_clause() -> Self {
177 Self::field(AstField::From)
178 }
179
180 pub fn group_by() -> Self {
181 Self::field(AstField::GroupBy)
182 }
183
184 pub fn having() -> Self {
185 Self::field(AstField::Having)
186 }
187
188 pub fn with_clause() -> Self {
189 Self::field(AstField::With)
190 }
191
192 pub fn order_by() -> Self {
193 Self::field(AstField::OrderBy)
194 }
195
196 pub fn left() -> Self {
197 Self::field(AstField::Left)
198 }
199
200 pub fn right() -> Self {
201 Self::field(AstField::Right)
202 }
203
204 pub fn operand() -> Self {
205 Self::field(AstField::Operand)
206 }
207
208 pub fn else_result() -> Self {
209 Self::field(AstField::ElseResult)
210 }
211
212 pub fn relation() -> Self {
213 Self::field(AstField::Relation)
214 }
215
216 pub fn joins() -> Self {
217 Self::field(AstField::Joins)
218 }
219
220 pub fn subquery() -> Self {
221 Self::field(AstField::Subquery)
222 }
223
224 pub fn cte_tables() -> Self {
225 Self::field(AstField::CteTable)
226 }
227
228 pub fn cte_inner() -> Self {
229 Self::field(AstField::CteInner)
230 }
231}
232
233impl fmt::Display for PathComponent {
234 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235 match self {
236 PathComponent::Field(field) => write!(f, ".{}", field),
237 PathComponent::Index(idx) => write!(f, "[{}]", idx),
238 }
239 }
240}
241
/// A path through the AST: the sequence of steps to follow from a root node.
/// The empty path denotes the root itself.
pub type AstPath = Vec<PathComponent>;
245
246pub fn display_ast_path(path: &AstPath) -> String {
248 path.iter().map(|c| c.to_string()).collect::<String>()
249}
250
/// An owned AST fragment that a path can point at.
///
/// The `strum` attributes below additionally derive a field-less companion
/// enum named `AstNodeType` with a `Display` impl; it is used in this file to
/// print the kind of a node.
#[derive(Debug, Clone, EnumDiscriminants)]
#[strum_discriminants(derive(strum::Display))]
#[strum_discriminants(name(AstNodeType))]
pub enum AstNode {
    // Single AST items.
    Statement(Statement),
    Query(Box<Query>),
    Select(Box<Select>),
    Expr(Expr),
    SelectItem(SelectItem),
    TableWithJoins(TableWithJoins),
    Join(Join),
    TableFactor(TableFactor),
    OrderByExpr(OrderByExpr),
    With(With),
    Cte(Cte),
    // List-valued fields; their elements are addressed with `PathComponent::Index`.
    ExprList(Vec<Expr>),
    SelectItemList(Vec<SelectItem>),
    TableList(Vec<TableWithJoins>),
    JoinList(Vec<Join>),
    OrderByList(Vec<OrderByExpr>),
    CteList(Vec<Cte>),
    // An optional nested node. NOTE(review): nothing in this file constructs
    // or matches this variant — confirm whether it is still needed.
    Option(Option<Box<AstNode>>),
}
277
278impl AstNode {
279 pub fn get_child(&self, component: &PathComponent) -> Option<AstNode> {
281 match (self, component) {
282 (AstNode::Statement(Statement::Query(query)), PathComponent::Field(field))
284 if *field == AstField::Query =>
285 {
286 Some(AstNode::Query(query.clone()))
287 }
288
289 (
290 AstNode::Statement(Statement::CreateView { query, .. }),
291 PathComponent::Field(field),
292 ) if *field == AstField::Query => Some(AstNode::Query(query.clone())),
293
294 (
296 AstNode::Statement(Statement::CreateView {
297 name: _,
298 columns: _,
299 ..
300 }),
301 PathComponent::Field(field),
302 ) => match field {
303 AstField::Name => None, AstField::Columns => None, _ => None,
306 },
307
308 (AstNode::Query(query), PathComponent::Field(field)) => match field {
310 AstField::Body => match &query.body {
311 SetExpr::Select(select) => Some(AstNode::Select(select.clone())),
312 SetExpr::Query(subquery) => Some(AstNode::Query(subquery.clone())),
313 SetExpr::SetOperation { left, .. } => {
314 match left.as_ref() {
316 SetExpr::Select(select) => Some(AstNode::Select(select.clone())),
317 SetExpr::Query(subquery) => Some(AstNode::Query(subquery.clone())),
318 _ => None,
319 }
320 }
321 _ => None,
322 },
323 AstField::With => query.with.as_ref().map(|w| AstNode::With(w.clone())),
324 AstField::OrderBy => {
325 if query.order_by.is_empty() {
326 None
327 } else {
328 Some(AstNode::OrderByList(query.order_by.clone()))
329 }
330 }
331 AstField::Limit => query.limit.as_ref().map(|e| AstNode::Expr(e.clone())),
332 AstField::Offset => None, _ => None,
334 },
335
336 (AstNode::Select(select), PathComponent::Field(field)) => match field {
338 AstField::Projection => {
339 if select.projection.is_empty() {
340 None
341 } else {
342 Some(AstNode::SelectItemList(select.projection.clone()))
343 }
344 }
345 AstField::Selection => select.selection.as_ref().map(|e| AstNode::Expr(e.clone())),
346 AstField::From => {
347 if select.from.is_empty() {
348 None
349 } else {
350 Some(AstNode::TableList(select.from.clone()))
351 }
352 }
353 AstField::GroupBy => {
354 if select.group_by.is_empty() {
355 None
356 } else {
357 Some(AstNode::ExprList(select.group_by.clone()))
358 }
359 }
360 AstField::Having => select.having.as_ref().map(|e| AstNode::Expr(e.clone())),
361 AstField::Distinct => None, _ => None,
363 },
364
365 (AstNode::SelectItemList(items), PathComponent::Index(idx)) => items
367 .get(*idx)
368 .map(|item| AstNode::SelectItem(item.clone())),
369 (AstNode::ExprList(exprs), PathComponent::Index(idx)) => {
370 exprs.get(*idx).map(|expr| AstNode::Expr(expr.clone()))
371 }
372 (AstNode::TableList(tables), PathComponent::Index(idx)) => tables
373 .get(*idx)
374 .map(|table| AstNode::TableWithJoins(table.clone())),
375 (AstNode::OrderByList(orders), PathComponent::Index(idx)) => orders
376 .get(*idx)
377 .map(|order| AstNode::OrderByExpr(order.clone())),
378
379 (AstNode::Expr(expr), PathComponent::Field(field)) => match (expr, field) {
381 (Expr::BinaryOp { left, .. }, AstField::Left) => Some(AstNode::Expr(*left.clone())),
382 (Expr::BinaryOp { right, .. }, AstField::Right) => {
383 Some(AstNode::Expr(*right.clone()))
384 }
385 (Expr::Case { operand, .. }, AstField::Operand) => {
386 operand.as_ref().map(|e| AstNode::Expr(*e.clone()))
387 }
388 (Expr::Case { else_result, .. }, AstField::ElseResult) => {
389 else_result.as_ref().map(|e| AstNode::Expr(*e.clone()))
390 }
391 (Expr::Exists(subquery), AstField::Subquery) => {
392 Some(AstNode::Query(subquery.clone()))
393 }
394 (Expr::Subquery(subquery), AstField::Subquery) => {
395 Some(AstNode::Query(subquery.clone()))
396 }
397 (Expr::Function(_func), AstField::Name) => None, (Expr::Nested(inner), AstField::Inner) => Some(AstNode::Expr(*inner.clone())),
399 (Expr::UnaryOp { expr, .. }, AstField::Expr) => Some(AstNode::Expr(*expr.clone())),
400 (Expr::Cast { expr, .. }, AstField::Expr) => Some(AstNode::Expr(*expr.clone())),
401 (Expr::IsNull(expr), AstField::Expr) => Some(AstNode::Expr(*expr.clone())),
402 (Expr::IsNotNull(expr), AstField::Expr) => Some(AstNode::Expr(*expr.clone())),
403 (
404 Expr::Between {
405 expr,
406 low: _,
407 high: _,
408 ..
409 },
410 AstField::Expr,
411 ) => Some(AstNode::Expr(*expr.clone())),
412 (Expr::Between { low, .. }, AstField::Low) => Some(AstNode::Expr(*low.clone())),
413 (Expr::Between { high, .. }, AstField::High) => Some(AstNode::Expr(*high.clone())),
414 _ => None,
415 },
416
417 (AstNode::TableWithJoins(table_with_joins), PathComponent::Field(field)) => match field
419 {
420 AstField::Relation => Some(AstNode::TableFactor(table_with_joins.relation.clone())),
421 AstField::Joins => {
422 if !table_with_joins.joins.is_empty() {
423 Some(AstNode::JoinList(table_with_joins.joins.clone()))
424 } else {
425 None
426 }
427 }
428 _ => None,
429 },
430
431 (AstNode::Join(join), PathComponent::Field(field)) => match field {
433 AstField::Relation => Some(AstNode::TableFactor(join.relation.clone())),
434 AstField::JoinOperator => None, _ => None,
436 },
437
438 (AstNode::TableFactor(table_factor), PathComponent::Field(field)) => {
440 match (table_factor, field) {
441 (TableFactor::Table { .. }, _) => None, (TableFactor::Derived { subquery, .. }, AstField::Subquery) => {
443 Some(AstNode::Query(subquery.clone()))
444 }
445 (TableFactor::TableFunction { .. }, _) => None, _ => None,
447 }
448 }
449
450 (AstNode::JoinList(joins), PathComponent::Index(idx)) => {
452 joins.get(*idx).map(|join| AstNode::Join(join.clone()))
453 }
454
455 (AstNode::SelectItem(select_item), PathComponent::Field(field)) => {
457 match (select_item, field) {
458 (SelectItem::UnnamedExpr(expr), AstField::Expr) => {
459 Some(AstNode::Expr(expr.clone()))
460 }
461 (SelectItem::ExprWithAlias { expr, .. }, AstField::Expr) => {
462 Some(AstNode::Expr(expr.clone()))
463 }
464 (SelectItem::QualifiedWildcard(..), _) => None,
465 (SelectItem::Wildcard(..), _) => None,
466 _ => None,
467 }
468 }
469
470 (AstNode::OrderByExpr(order_by), PathComponent::Field(field)) => match field {
472 AstField::Expr => Some(AstNode::Expr(order_by.expr.clone())),
473 AstField::Asc => None, AstField::NullsFirst => None, _ => None,
476 },
477
478 (AstNode::With(with_clause), PathComponent::Field(field)) => match field {
480 AstField::CteTable => {
481 if with_clause.cte_tables.is_empty() {
482 None
483 } else {
484 Some(AstNode::CteList(with_clause.cte_tables.clone()))
485 }
486 }
487 AstField::Recursive => None, _ => None,
489 },
490
491 (AstNode::CteList(ctes), PathComponent::Index(idx)) => {
493 if *idx < ctes.len() {
494 Some(AstNode::Cte(ctes[*idx].clone()))
495 } else {
496 None
497 }
498 }
499
500 (AstNode::Cte(cte), PathComponent::Field(field)) => match field {
502 AstField::Alias => None, AstField::CteInner => match &cte.cte_inner {
504 CteInner::Query(query) => Some(AstNode::Query(query.clone())),
505 CteInner::ChangeLog(_) => None, },
507 _ => None,
508 },
509
510 _ => {
511 tracing::debug!(
513 "get_child: No match for {:?} with component {:?}",
514 std::mem::discriminant(self),
515 component
516 );
517 None
518 }
519 }
520 }
521
522 pub fn set_child(
525 &self,
526 component: &PathComponent,
527 new_child: Option<AstNode>,
528 ) -> Option<AstNode> {
529 match (self, component) {
530 (AstNode::Statement(Statement::Query(_)), PathComponent::Field(field))
532 if *field == AstField::Query =>
533 {
534 if let Some(AstNode::Query(new_query)) = new_child {
535 Some(AstNode::Statement(Statement::Query(new_query)))
536 } else {
537 None
538 }
539 }
540
541 (
542 AstNode::Statement(Statement::CreateView {
543 name,
544 columns,
545 query: _,
546 or_replace,
547 materialized,
548 if_not_exists,
549 emit_mode,
550 with_options,
551 }),
552 PathComponent::Field(field),
553 ) => match field {
554 AstField::Query => {
555 if let Some(AstNode::Query(new_query)) = new_child {
556 let new_stmt = Statement::CreateView {
557 or_replace: *or_replace,
558 materialized: *materialized,
559 if_not_exists: *if_not_exists,
560 name: name.clone(),
561 columns: columns.clone(),
562 query: new_query,
563 emit_mode: emit_mode.clone(),
564 with_options: with_options.clone(),
565 };
566 Some(AstNode::Statement(new_stmt))
567 } else {
568 None
569 }
570 }
571 _ => None,
572 },
573
574 (AstNode::Query(query), PathComponent::Field(field)) => {
576 let mut new_query = (**query).clone();
577 match field {
578 AstField::Body => {
579 if let Some(AstNode::Select(select)) = new_child {
580 new_query.body = SetExpr::Select(select);
581 Some(AstNode::Query(Box::new(new_query)))
582 } else {
583 None
584 }
585 }
586 AstField::OrderBy => {
587 if let Some(AstNode::OrderByList(orders)) = new_child {
588 new_query.order_by = orders;
589 } else {
590 new_query.order_by = vec![];
591 }
592 Some(AstNode::Query(Box::new(new_query)))
593 }
594 AstField::Limit => {
595 new_query.limit = new_child.and_then(|n| match n {
596 AstNode::Expr(e) => Some(e),
597 _ => None,
598 });
599 Some(AstNode::Query(Box::new(new_query)))
600 }
601 AstField::With => {
602 new_query.with = new_child.and_then(|n| match n {
603 AstNode::With(w) => Some(w),
604 _ => None,
605 });
606 Some(AstNode::Query(Box::new(new_query)))
607 }
608 _ => None,
609 }
610 }
611
612 (AstNode::Select(select), PathComponent::Field(field)) => {
614 let mut new_select = (**select).clone();
615 match field {
616 AstField::Selection => {
617 new_select.selection = new_child.and_then(|n| match n {
618 AstNode::Expr(e) => Some(e),
619 _ => None,
620 });
621 }
622 AstField::Having => {
623 new_select.having = new_child.and_then(|n| match n {
624 AstNode::Expr(e) => Some(e),
625 _ => None,
626 });
627 }
628 AstField::Projection => {
629 if let Some(AstNode::SelectItemList(items)) = new_child {
630 new_select.projection = items;
631 } else {
632 new_select.projection = vec![];
634 }
635 }
636 AstField::From => {
637 if let Some(AstNode::TableList(tables)) = new_child {
638 new_select.from = tables;
639 } else {
640 new_select.from = vec![];
642 }
643 }
644 AstField::GroupBy => {
645 if let Some(AstNode::ExprList(exprs)) = new_child {
646 new_select.group_by = exprs;
647 } else {
648 new_select.group_by = vec![];
650 }
651 }
652 _ => return None,
653 }
654 Some(AstNode::Select(Box::new(new_select)))
655 }
656
657 (AstNode::SelectItemList(items), PathComponent::Index(idx)) => {
659 let mut new_items = items.clone();
660 if *idx < new_items.len() {
661 if let Some(AstNode::SelectItem(item)) = new_child {
662 new_items[*idx] = item;
663 } else {
664 new_items.remove(*idx);
665 }
666 Some(AstNode::SelectItemList(new_items))
667 } else {
668 None
669 }
670 }
671
672 (AstNode::ExprList(exprs), PathComponent::Index(idx)) => {
673 let mut new_exprs = exprs.clone();
674 if *idx < new_exprs.len() {
675 if let Some(AstNode::Expr(expr)) = new_child {
676 new_exprs[*idx] = expr;
677 } else {
678 new_exprs.remove(*idx);
679 }
680 Some(AstNode::ExprList(new_exprs))
681 } else {
682 None
683 }
684 }
685
686 (AstNode::TableList(tables), PathComponent::Index(idx)) => {
687 let mut new_tables = tables.clone();
688 if *idx < new_tables.len() {
689 if let Some(AstNode::TableWithJoins(table)) = new_child {
690 new_tables[*idx] = table;
691 } else {
692 new_tables.remove(*idx);
693 }
694 Some(AstNode::TableList(new_tables))
695 } else {
696 None
697 }
698 }
699
700 (AstNode::TableWithJoins(table_with_joins), PathComponent::Field(field)) => {
702 match field {
703 AstField::Relation => {
704 if let Some(AstNode::TableFactor(new_relation)) = new_child {
705 Some(AstNode::TableWithJoins(TableWithJoins {
706 relation: new_relation,
707 joins: table_with_joins.joins.clone(),
708 }))
709 } else {
710 None
711 }
712 }
713 AstField::Joins => {
714 if let Some(AstNode::JoinList(new_joins)) = new_child {
715 Some(AstNode::TableWithJoins(TableWithJoins {
716 relation: table_with_joins.relation.clone(),
717 joins: new_joins,
718 }))
719 } else {
720 Some(AstNode::TableWithJoins(TableWithJoins {
722 relation: table_with_joins.relation.clone(),
723 joins: vec![],
724 }))
725 }
726 }
727 _ => None,
728 }
729 }
730
731 (AstNode::Join(join), PathComponent::Field(field)) => match field {
733 AstField::Relation => {
734 if let Some(AstNode::TableFactor(new_relation)) = new_child {
735 Some(AstNode::Join(Join {
736 relation: new_relation,
737 join_operator: join.join_operator.clone(),
738 }))
739 } else {
740 None
741 }
742 }
743 _ => None,
744 },
745
746 (AstNode::TableFactor(table_factor), PathComponent::Field(field)) => {
748 match (table_factor, field) {
749 (TableFactor::Derived { lateral, alias, .. }, AstField::Subquery) => {
750 if let Some(AstNode::Query(new_subquery)) = new_child {
751 Some(AstNode::TableFactor(TableFactor::Derived {
752 lateral: *lateral,
753 subquery: new_subquery,
754 alias: alias.clone(),
755 }))
756 } else {
757 None
758 }
759 }
760 _ => None,
761 }
762 }
763
764 (AstNode::JoinList(joins), PathComponent::Index(idx)) => {
766 let mut new_joins = joins.clone();
767 if *idx < new_joins.len() {
768 if let Some(AstNode::Join(new_join)) = new_child {
769 new_joins[*idx] = new_join;
770 } else {
771 new_joins.remove(*idx);
772 }
773 Some(AstNode::JoinList(new_joins))
774 } else {
775 None
776 }
777 }
778
779 (AstNode::OrderByList(orders), PathComponent::Index(idx)) => {
780 let mut new_orders = orders.clone();
781 if *idx < new_orders.len() {
782 if let Some(AstNode::OrderByExpr(order)) = new_child {
783 new_orders[*idx] = order;
784 } else {
785 new_orders.remove(*idx);
786 }
787 Some(AstNode::OrderByList(new_orders))
788 } else {
789 None
790 }
791 }
792
793 (AstNode::Expr(expr), PathComponent::Field(field)) => match (expr, field) {
795 (Expr::BinaryOp { left: _, op, right }, AstField::Left) => {
796 if let Some(AstNode::Expr(new_left)) = new_child {
797 Some(AstNode::Expr(Expr::BinaryOp {
798 left: Box::new(new_left),
799 op: op.clone(),
800 right: right.clone(),
801 }))
802 } else {
803 None
804 }
805 }
806 (Expr::BinaryOp { left, op, right: _ }, AstField::Right) => {
807 if let Some(AstNode::Expr(new_right)) = new_child {
808 Some(AstNode::Expr(Expr::BinaryOp {
809 left: left.clone(),
810 op: op.clone(),
811 right: Box::new(new_right),
812 }))
813 } else {
814 None
815 }
816 }
817 (Expr::Nested(_), AstField::Inner) => {
818 if let Some(AstNode::Expr(new_inner)) = new_child {
819 Some(AstNode::Expr(Expr::Nested(Box::new(new_inner))))
820 } else {
821 None
822 }
823 }
824 _ => None,
825 },
826
827 (AstNode::SelectItem(select_item), PathComponent::Field(field)) => {
829 match (select_item, field) {
830 (SelectItem::UnnamedExpr(_), AstField::Expr) => {
831 if let Some(AstNode::Expr(new_expr)) = new_child {
832 Some(AstNode::SelectItem(SelectItem::UnnamedExpr(new_expr)))
833 } else {
834 None
835 }
836 }
837 (SelectItem::ExprWithAlias { alias, .. }, AstField::Expr) => {
838 if let Some(AstNode::Expr(new_expr)) = new_child {
839 Some(AstNode::SelectItem(SelectItem::ExprWithAlias {
840 expr: new_expr,
841 alias: alias.clone(),
842 }))
843 } else {
844 None
845 }
846 }
847 _ => None,
848 }
849 }
850
851 (AstNode::OrderByExpr(order_by), PathComponent::Field(field)) => match field {
853 AstField::Expr => {
854 if let Some(AstNode::Expr(new_expr)) = new_child {
855 Some(AstNode::OrderByExpr(OrderByExpr {
856 expr: new_expr,
857 asc: order_by.asc,
858 nulls_first: order_by.nulls_first,
859 }))
860 } else {
861 None
862 }
863 }
864 _ => None,
865 },
866
867 (AstNode::With(with_clause), PathComponent::Field(field)) => match field {
869 AstField::CteTable => {
870 if let Some(AstNode::CteList(new_ctes)) = new_child {
871 let mut new_with = with_clause.clone();
872 new_with.cte_tables = new_ctes;
873 Some(AstNode::With(new_with))
874 } else {
875 let mut new_with = with_clause.clone();
877 new_with.cte_tables = vec![];
878 Some(AstNode::With(new_with))
879 }
880 }
881 _ => None,
882 },
883
884 (AstNode::CteList(ctes), PathComponent::Index(idx)) => {
886 if let Some(AstNode::Cte(new_cte)) = new_child {
887 if *idx < ctes.len() {
888 let mut new_ctes = ctes.clone();
889 new_ctes[*idx] = new_cte;
890 Some(AstNode::CteList(new_ctes))
891 } else {
892 None
893 }
894 } else {
895 None
896 }
897 }
898
899 (AstNode::Cte(cte), PathComponent::Field(field)) => match field {
901 AstField::CteInner => {
902 if let Some(AstNode::Query(new_query)) = new_child {
903 let mut new_cte = cte.clone();
904 new_cte.cte_inner = CteInner::Query(new_query);
905 Some(AstNode::Cte(new_cte))
906 } else {
907 None
908 }
909 }
910 _ => None,
911 },
912
913 _ => {
914 tracing::debug!(
916 "set_child: No match for {} ({:?}) with component {:?}",
917 AstNodeType::from(self),
918 std::mem::discriminant(self),
919 component
920 );
921 None
922 }
923 }
924 }
925
926 pub fn to_statement(&self) -> Option<Statement> {
928 match self {
929 AstNode::Statement(stmt) => Some(stmt.clone()),
930 _ => None,
931 }
932 }
933}
934
935pub fn get_node_at_path(root: &AstNode, path: &AstPath) -> Option<AstNode> {
938 let mut current = root.clone();
939 for component in path {
940 current = current.get_child(component)?;
941 }
942 Some(current)
943}
944
945pub fn set_node_at_path(
948 root: &AstNode,
949 path: &AstPath,
950 new_node: Option<AstNode>,
951) -> Option<AstNode> {
952 if path.is_empty() {
953 return new_node;
954 }
955
956 let result = root.clone();
957 let mut current_path = Vec::new();
958
959 for component in &path[..path.len() - 1] {
961 current_path.push(component.clone());
962 }
963
964 let parent = get_node_at_path(&result, ¤t_path)?;
966
967 let modified_parent = parent.set_child(&path[path.len() - 1], new_node)?;
969
970 if current_path.is_empty() {
972 Some(modified_parent)
973 } else {
974 set_node_at_path(&result, ¤t_path, Some(modified_parent))
975 }
976}
977
978fn explore_child_field(
980 node: &AstNode,
981 field: AstField,
982 current_path: &AstPath,
983 paths: &mut Vec<AstPath>,
984) {
985 let field_component = PathComponent::field(field);
986 let child_path = [current_path.clone(), vec![field_component.clone()]].concat();
987 let relative_path = vec![field_component];
988
989 if let Some(child_node) = get_node_at_path(node, &relative_path) {
990 let child_paths = enumerate_reduction_paths(&child_node, child_path);
992 paths.extend(child_paths);
993 }
994}
995
996fn path_depth(path: &AstPath) -> usize {
999 path.len()
1000}
1001
1002pub fn enumerate_reduction_paths(node: &AstNode, current_path: AstPath) -> Vec<AstPath> {
1005 let mut paths = vec![current_path.clone()];
1006
1007 tracing::debug!(
1008 "Enumerating paths for node {:?} at path {}",
1009 get_node_type_name(node),
1010 display_ast_path(¤t_path)
1011 );
1012
1013 match node {
1014 AstNode::Statement(Statement::Query(_)) => {
1015 explore_child_field(node, AstField::Query, ¤t_path, &mut paths);
1016 }
1017
1018 AstNode::Statement(Statement::CreateView { .. }) => {
1019 explore_child_field(node, AstField::Query, ¤t_path, &mut paths);
1020 }
1021
1022 AstNode::Query(_query) => {
1023 explore_child_field(node, AstField::Body, ¤t_path, &mut paths);
1024 explore_child_field(node, AstField::With, ¤t_path, &mut paths);
1025 explore_child_field(node, AstField::OrderBy, ¤t_path, &mut paths);
1026 }
1027
1028 AstNode::Select(_) => {
1029 explore_child_field(node, AstField::Projection, ¤t_path, &mut paths);
1030 explore_child_field(node, AstField::From, ¤t_path, &mut paths);
1031 explore_child_field(node, AstField::GroupBy, ¤t_path, &mut paths);
1032
1033 let selection_path = [current_path.clone(), vec![PathComponent::selection()]].concat();
1035 if get_node_at_path(node, &vec![PathComponent::selection()]).is_some() {
1036 paths.push(selection_path);
1037 }
1038
1039 let having_path = [current_path.clone(), vec![PathComponent::having()]].concat();
1040 if get_node_at_path(node, &vec![PathComponent::having()]).is_some() {
1041 paths.push(having_path);
1042 }
1043 }
1044
1045 AstNode::SelectItemList(items) => {
1047 for i in 0..items.len() {
1048 let item_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1049 paths.push(item_path);
1050 }
1051 }
1052
1053 AstNode::ExprList(exprs) => {
1054 for i in 0..exprs.len() {
1055 let expr_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1056 paths.push(expr_path.clone());
1057 let relative_path = vec![PathComponent::Index(i)];
1059 if let Some(expr_node) = get_node_at_path(node, &relative_path) {
1060 paths.extend(enumerate_reduction_paths(&expr_node, expr_path));
1061 }
1062 }
1063 }
1064
1065 AstNode::TableList(tables) => {
1066 for i in 0..tables.len() {
1067 let table_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1068 paths.push(table_path.clone());
1069 if let Some(table_node) = node.get_child(&PathComponent::Index(i)) {
1071 paths.extend(enumerate_reduction_paths(&table_node, table_path));
1072 }
1073 }
1074 }
1075
1076 AstNode::TableWithJoins(_) => {
1078 explore_child_field(node, AstField::Relation, ¤t_path, &mut paths);
1079 explore_child_field(node, AstField::Joins, ¤t_path, &mut paths);
1080 }
1081
1082 AstNode::Join(_) => {
1084 explore_child_field(node, AstField::Relation, ¤t_path, &mut paths);
1085 }
1086
1087 AstNode::TableFactor(_) => {
1089 explore_child_field(node, AstField::Subquery, ¤t_path, &mut paths);
1091 }
1092
1093 AstNode::JoinList(joins) => {
1095 for i in 0..joins.len() {
1096 let join_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1097 paths.push(join_path.clone());
1098 if let Some(join_node) = node.get_child(&PathComponent::Index(i)) {
1100 paths.extend(enumerate_reduction_paths(&join_node, join_path));
1101 }
1102 }
1103 }
1104
1105 AstNode::OrderByList(orders) => {
1106 for i in 0..orders.len() {
1107 let order_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1108 paths.push(order_path);
1109 }
1110 }
1111
1112 AstNode::Expr(expr) => match expr {
1114 Expr::BinaryOp { .. } => {
1115 let left_path = [current_path.clone(), vec![PathComponent::left()]].concat();
1116 let right_path = [current_path.clone(), vec![PathComponent::right()]].concat();
1117 paths.push(left_path);
1118 paths.push(right_path);
1119 }
1120 Expr::Case {
1121 operand,
1122 else_result,
1123 ..
1124 } => {
1125 if operand.is_some() {
1126 let operand_path =
1127 [current_path.clone(), vec![PathComponent::operand()]].concat();
1128 paths.push(operand_path);
1129 }
1130 if else_result.is_some() {
1131 let else_path =
1132 [current_path.clone(), vec![PathComponent::else_result()]].concat();
1133 paths.push(else_path);
1134 }
1135 }
1136 Expr::Exists(_) => {
1137 explore_child_field(node, AstField::Subquery, ¤t_path, &mut paths);
1138 }
1139 Expr::Subquery(_) => {
1140 explore_child_field(node, AstField::Subquery, ¤t_path, &mut paths);
1141 }
1142 _ => {}
1143 },
1144
1145 AstNode::With(_) => {
1147 explore_child_field(node, AstField::CteTable, ¤t_path, &mut paths);
1148 }
1149
1150 AstNode::CteList(ctes) => {
1152 for i in 0..ctes.len() {
1153 let cte_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1154 paths.push(cte_path.clone());
1155
1156 if let Some(cte_node) = get_node_at_path(node, &vec![PathComponent::Index(i)]) {
1158 paths.extend(enumerate_reduction_paths(&cte_node, cte_path));
1159 }
1160 }
1161 }
1162
1163 AstNode::Cte(_) => {
1165 explore_child_field(node, AstField::CteInner, ¤t_path, &mut paths);
1166 }
1167
1168 _ => {}
1169 }
1170
1171 paths.sort_by_key(path_depth);
1175
1176 paths
1177}
1178
1179pub fn statement_to_ast_node(stmt: &Statement) -> AstNode {
1181 AstNode::Statement(stmt.clone())
1182}
1183
1184pub fn ast_node_to_statement(node: &AstNode) -> Option<Statement> {
1186 match node {
1187 AstNode::Statement(stmt) => Some(stmt.clone()),
1188 AstNode::Query(query) => Some(Statement::Query(Box::new(query.as_ref().clone()))),
1189 _ => None,
1190 }
1191}
1192
1193pub fn get_node_type_name(node: &AstNode) -> String {
1195 AstNodeType::from(node).to_string()
1196}
1197
#[cfg(test)]
mod tests {
    use risingwave_sqlparser::parser::Parser;

    use super::*;

    /// A minimal `SELECT` must yield at least the root path plus a few
    /// descendants.
    #[test]
    fn test_path_enumeration() {
        let sql = "SELECT a FROM b;";
        let parsed = Parser::parse_sql(sql).expect("Failed to parse SQL");
        let stmt = &parsed[0];
        let ast_node = statement_to_ast_node(stmt);

        let paths = enumerate_reduction_paths(&ast_node, vec![]);

        // Only a lower bound is checked; the exact count is not pinned.
        assert!(paths.len() >= 3);
        println!("Found {} paths for simple query", paths.len());
    }

    /// Enumeration must also work when the query is wrapped in a
    /// `CREATE MATERIALIZED VIEW` statement.
    #[test]
    fn test_create_materialized_view() {
        let sql = "CREATE MATERIALIZED VIEW stream_query AS SELECT min((tumble_0.c5 + tumble_0.c5)) AS col_0, false AS col_1, TIME '07:30:48' AS col_2, tumble_0.c14 AS col_3 FROM tumble(alltypes1, alltypes1.c11, INTERVAL '10') AS tumble_0 WHERE tumble_0.c1 GROUP BY tumble_0.c10, tumble_0.c9, tumble_0.c13, tumble_0.c1, tumble_0.c16, tumble_0.c14, tumble_0.c5, tumble_0.c8;";
        let parsed = Parser::parse_sql(sql).expect("Failed to parse SQL");
        let stmt = &parsed[0];
        let ast_node = statement_to_ast_node(stmt);

        let paths = enumerate_reduction_paths(&ast_node, vec![]);

        // Print every path to help debugging when the test is run with `--nocapture`.
        println!("Found {} paths for CREATE MATERIALIZED VIEW", paths.len());
        for (i, path) in paths.iter().enumerate() {
            println!("Path {}: {}", i, display_ast_path(path));
        }

        assert!(!paths.is_empty());

        // Paths are sorted by depth, so the first one is the shallowest and
        // must resolve to a node.
        assert!(get_node_at_path(&ast_node, &paths[0]).is_some());
    }
}