risingwave_sqlsmith/sqlreduce/
path.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Path-based AST navigation for SQL reduction.
16//!
17//! This module provides utilities for navigating and modifying SQL ASTs using
18//! path-based addressing for precise AST manipulation.
19
20use std::fmt;
21
22use risingwave_sqlparser::ast::*;
23use strum::EnumDiscriminants;
24
/// Represents all possible AST field names that can be navigated.
///
/// Addressing fields through an enum instead of raw strings lets the compiler
/// reject misspelled field names. The `Display` impl renders each variant as
/// the snake_case name of the field it addresses (e.g. `OrderBy` renders as
/// `order_by`); note that `CteTable` renders as `cte_tables`, the name of the
/// list field on a `WITH` clause.
///
/// `Hash` is derived alongside `Eq` so fields can be used as keys in
/// `HashMap`/`HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AstField {
    // Statement fields
    Query,
    Name,
    Columns,

    // Query fields
    Body,
    With,
    OrderBy,
    Limit,
    Offset,

    // Select fields
    Projection,
    Selection,
    From,
    GroupBy,
    Having,
    Distinct,

    // Expression fields (`Subquery` is also used for derived tables)
    Left,
    Right,
    Operand,
    ElseResult,
    Subquery,
    Inner,
    Expr,
    Low,
    High,

    // TableWithJoins fields (`Relation` is also used for joins)
    Relation,
    Joins,

    // Join fields
    JoinOperator,

    // SelectItem / CTE fields
    Alias,

    // OrderByExpr fields
    Asc,
    NullsFirst,

    // With clause fields
    CteTable,
    Recursive,

    // CTE fields
    CteInner,

    // TableFactor fields
    Lateral,
}
84
85impl fmt::Display for AstField {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        let s = match self {
88            AstField::Query => "query",
89            AstField::Name => "name",
90            AstField::Columns => "columns",
91            AstField::Body => "body",
92            AstField::With => "with",
93            AstField::OrderBy => "order_by",
94            AstField::Limit => "limit",
95            AstField::Offset => "offset",
96            AstField::Projection => "projection",
97            AstField::Selection => "selection",
98            AstField::From => "from",
99            AstField::GroupBy => "group_by",
100            AstField::Having => "having",
101            AstField::Distinct => "distinct",
102            AstField::Left => "left",
103            AstField::Right => "right",
104            AstField::Operand => "operand",
105            AstField::ElseResult => "else_result",
106            AstField::Subquery => "subquery",
107            AstField::Inner => "inner",
108            AstField::Expr => "expr",
109            AstField::Low => "low",
110            AstField::High => "high",
111            AstField::Relation => "relation",
112            AstField::Joins => "joins",
113            AstField::JoinOperator => "join_operator",
114            AstField::Alias => "alias",
115            AstField::Asc => "asc",
116            AstField::NullsFirst => "nulls_first",
117            AstField::CteTable => "cte_tables",
118            AstField::Recursive => "recursive",
119            AstField::CteInner => "cte_inner",
120            AstField::Lateral => "lateral",
121        };
122        write!(f, "{}", s)
123    }
124}
125
126/// Represents a path component in an AST navigation path.
127/// Components that make up a path through the AST.
128/// Enables precise navigation to any AST node.
129#[derive(Debug, Clone, PartialEq, Eq)]
130pub enum PathComponent {
131    /// Field access by AST field enum (type-safe)
132    Field(AstField),
133    /// Array/Vec index access
134    Index(usize),
135}
136
137impl PathComponent {
138    /// Create a Field `PathComponent` from `AstField` enum (preferred)
139    pub fn field(field: AstField) -> Self {
140        PathComponent::Field(field)
141    }
142
143    /// Get the field name as owned String for compatibility
144    pub fn field_name(&self) -> Option<String> {
145        match self {
146            PathComponent::Field(field) => Some(field.to_string()),
147            PathComponent::Index(_) => None,
148        }
149    }
150
151    /// Get the `AstField` enum if this is a field
152    pub fn as_ast_field(&self) -> Option<&AstField> {
153        match self {
154            PathComponent::Field(field) => Some(field),
155            PathComponent::Index(_) => None,
156        }
157    }
158
159    // Convenience constructors for common fields
160    pub fn query() -> Self {
161        Self::field(AstField::Query)
162    }
163
164    pub fn body() -> Self {
165        Self::field(AstField::Body)
166    }
167
168    pub fn selection() -> Self {
169        Self::field(AstField::Selection)
170    }
171
172    pub fn projection() -> Self {
173        Self::field(AstField::Projection)
174    }
175
176    pub fn from_clause() -> Self {
177        Self::field(AstField::From)
178    }
179
180    pub fn group_by() -> Self {
181        Self::field(AstField::GroupBy)
182    }
183
184    pub fn having() -> Self {
185        Self::field(AstField::Having)
186    }
187
188    pub fn with_clause() -> Self {
189        Self::field(AstField::With)
190    }
191
192    pub fn order_by() -> Self {
193        Self::field(AstField::OrderBy)
194    }
195
196    pub fn left() -> Self {
197        Self::field(AstField::Left)
198    }
199
200    pub fn right() -> Self {
201        Self::field(AstField::Right)
202    }
203
204    pub fn operand() -> Self {
205        Self::field(AstField::Operand)
206    }
207
208    pub fn else_result() -> Self {
209        Self::field(AstField::ElseResult)
210    }
211
212    pub fn relation() -> Self {
213        Self::field(AstField::Relation)
214    }
215
216    pub fn joins() -> Self {
217        Self::field(AstField::Joins)
218    }
219
220    pub fn subquery() -> Self {
221        Self::field(AstField::Subquery)
222    }
223
224    pub fn cte_tables() -> Self {
225        Self::field(AstField::CteTable)
226    }
227
228    pub fn cte_inner() -> Self {
229        Self::field(AstField::CteInner)
230    }
231}
232
233impl fmt::Display for PathComponent {
234    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235        match self {
236            PathComponent::Field(field) => write!(f, ".{}", field),
237            PathComponent::Index(idx) => write!(f, "[{}]", idx),
238        }
239    }
240}
241
242/// A path through the AST for precise node identification.
243/// This allows us to precisely identify and modify any node in the tree.
244pub type AstPath = Vec<PathComponent>;
245
246// Note: Display implementation moved to helper function to avoid orphan rule
247pub fn display_ast_path(path: &AstPath) -> String {
248    path.iter().map(|c| c.to_string()).collect::<String>()
249}
250
/// Represents a node in the AST that can be navigated and modified.
///
/// This is a simplified, owned wrapper over the `risingwave_sqlparser` AST,
/// focusing on the most commonly reduced SQL constructs. Besides single nodes
/// it has dedicated list variants (`*List`), which are the nodes that
/// `PathComponent::Index` steps operate on.
///
/// The `EnumDiscriminants` derive generates a companion fieldless enum named
/// `AstNodeType` with one variant per variant below; it also derives
/// `strum::Display`, so a node's kind can be printed without its payload.
#[derive(Debug, Clone, EnumDiscriminants)]
#[strum_discriminants(derive(strum::Display))]
#[strum_discriminants(name(AstNodeType))]
pub enum AstNode {
    /// A whole SQL statement (e.g. a query or `CREATE VIEW`).
    Statement(Statement),
    /// A query: `WITH` clause, body, `ORDER BY`, `LIMIT`, ...
    Query(Box<Query>),
    /// A single `SELECT` block: projection, `FROM`, `WHERE`, `GROUP BY`, `HAVING`.
    Select(Box<Select>),
    /// A scalar expression.
    Expr(Expr),
    /// One item of a `SELECT` projection list.
    SelectItem(SelectItem),
    /// One `FROM` entry: a relation together with its joins.
    TableWithJoins(TableWithJoins),
    /// One join attached to a `TableWithJoins`.
    Join(Join),
    /// A relation in `FROM` (table, derived subquery, table function, ...).
    TableFactor(TableFactor),
    /// One `ORDER BY` item.
    OrderByExpr(OrderByExpr),
    /// A `WITH` clause.
    With(With),
    /// A single common table expression.
    Cte(Cte),
    /// A list of expressions (e.g. a `GROUP BY` list).
    ExprList(Vec<Expr>),
    /// A `SELECT` projection list.
    SelectItemList(Vec<SelectItem>),
    /// The `FROM` list of a `SELECT`.
    TableList(Vec<TableWithJoins>),
    /// The joins of a `TableWithJoins`.
    JoinList(Vec<Join>),
    /// The `ORDER BY` list of a query.
    OrderByList(Vec<OrderByExpr>),
    /// The CTEs of a `WITH` clause.
    CteList(Vec<Cte>),
    /// An optional node (`None` when absent).
    /// NOTE(review): not produced by `get_child`; confirm where it is used.
    Option(Option<Box<AstNode>>),
}
277
278impl AstNode {
279    /// Navigate to a child node using a path component.
280    pub fn get_child(&self, component: &PathComponent) -> Option<AstNode> {
281        match (self, component) {
282            // Statement navigation
283            (AstNode::Statement(Statement::Query(query)), PathComponent::Field(field))
284                if *field == AstField::Query =>
285            {
286                Some(AstNode::Query(query.clone()))
287            }
288
289            (
290                AstNode::Statement(Statement::CreateView { query, .. }),
291                PathComponent::Field(field),
292            ) if *field == AstField::Query => Some(AstNode::Query(query.clone())),
293
294            // More CreateView field access
295            (
296                AstNode::Statement(Statement::CreateView {
297                    name: _,
298                    columns: _,
299                    ..
300                }),
301                PathComponent::Field(field),
302            ) => match field {
303                AstField::Name => None,    // ObjectName is complex, skip for now
304                AstField::Columns => None, // Column list is complex, skip for now
305                _ => None,
306            },
307
308            // Query navigation - enhanced
309            (AstNode::Query(query), PathComponent::Field(field)) => match field {
310                AstField::Body => match &query.body {
311                    SetExpr::Select(select) => Some(AstNode::Select(select.clone())),
312                    SetExpr::Query(subquery) => Some(AstNode::Query(subquery.clone())),
313                    SetExpr::SetOperation { left, .. } => {
314                        // For SetOperation, recursively handle the left side
315                        match left.as_ref() {
316                            SetExpr::Select(select) => Some(AstNode::Select(select.clone())),
317                            SetExpr::Query(subquery) => Some(AstNode::Query(subquery.clone())),
318                            _ => None,
319                        }
320                    }
321                    _ => None,
322                },
323                AstField::With => query.with.as_ref().map(|w| AstNode::With(w.clone())),
324                AstField::OrderBy => {
325                    if query.order_by.is_empty() {
326                        None
327                    } else {
328                        Some(AstNode::OrderByList(query.order_by.clone()))
329                    }
330                }
331                AstField::Limit => query.limit.as_ref().map(|e| AstNode::Expr(e.clone())),
332                AstField::Offset => None, // offset is Option<String>, not navigable as expression
333                _ => None,
334            },
335
336            // Select navigation - enhanced
337            (AstNode::Select(select), PathComponent::Field(field)) => match field {
338                AstField::Projection => {
339                    if select.projection.is_empty() {
340                        None
341                    } else {
342                        Some(AstNode::SelectItemList(select.projection.clone()))
343                    }
344                }
345                AstField::Selection => select.selection.as_ref().map(|e| AstNode::Expr(e.clone())),
346                AstField::From => {
347                    if select.from.is_empty() {
348                        None
349                    } else {
350                        Some(AstNode::TableList(select.from.clone()))
351                    }
352                }
353                AstField::GroupBy => {
354                    if select.group_by.is_empty() {
355                        None
356                    } else {
357                        Some(AstNode::ExprList(select.group_by.clone()))
358                    }
359                }
360                AstField::Having => select.having.as_ref().map(|e| AstNode::Expr(e.clone())),
361                AstField::Distinct => None, // Distinct is an enum, handle separately if needed
362                _ => None,
363            },
364
365            // List navigation
366            (AstNode::SelectItemList(items), PathComponent::Index(idx)) => items
367                .get(*idx)
368                .map(|item| AstNode::SelectItem(item.clone())),
369            (AstNode::ExprList(exprs), PathComponent::Index(idx)) => {
370                exprs.get(*idx).map(|expr| AstNode::Expr(expr.clone()))
371            }
372            (AstNode::TableList(tables), PathComponent::Index(idx)) => tables
373                .get(*idx)
374                .map(|table| AstNode::TableWithJoins(table.clone())),
375            (AstNode::OrderByList(orders), PathComponent::Index(idx)) => orders
376                .get(*idx)
377                .map(|order| AstNode::OrderByExpr(order.clone())),
378
379            // Expression navigation (for pullup operations) - enhanced
380            (AstNode::Expr(expr), PathComponent::Field(field)) => match (expr, field) {
381                (Expr::BinaryOp { left, .. }, AstField::Left) => Some(AstNode::Expr(*left.clone())),
382                (Expr::BinaryOp { right, .. }, AstField::Right) => {
383                    Some(AstNode::Expr(*right.clone()))
384                }
385                (Expr::Case { operand, .. }, AstField::Operand) => {
386                    operand.as_ref().map(|e| AstNode::Expr(*e.clone()))
387                }
388                (Expr::Case { else_result, .. }, AstField::ElseResult) => {
389                    else_result.as_ref().map(|e| AstNode::Expr(*e.clone()))
390                }
391                (Expr::Exists(subquery), AstField::Subquery) => {
392                    Some(AstNode::Query(subquery.clone()))
393                }
394                (Expr::Subquery(subquery), AstField::Subquery) => {
395                    Some(AstNode::Query(subquery.clone()))
396                }
397                (Expr::Function(_func), AstField::Name) => None, // ObjectName is complex
398                (Expr::Nested(inner), AstField::Inner) => Some(AstNode::Expr(*inner.clone())),
399                (Expr::UnaryOp { expr, .. }, AstField::Expr) => Some(AstNode::Expr(*expr.clone())),
400                (Expr::Cast { expr, .. }, AstField::Expr) => Some(AstNode::Expr(*expr.clone())),
401                (Expr::IsNull(expr), AstField::Expr) => Some(AstNode::Expr(*expr.clone())),
402                (Expr::IsNotNull(expr), AstField::Expr) => Some(AstNode::Expr(*expr.clone())),
403                (
404                    Expr::Between {
405                        expr,
406                        low: _,
407                        high: _,
408                        ..
409                    },
410                    AstField::Expr,
411                ) => Some(AstNode::Expr(*expr.clone())),
412                (Expr::Between { low, .. }, AstField::Low) => Some(AstNode::Expr(*low.clone())),
413                (Expr::Between { high, .. }, AstField::High) => Some(AstNode::Expr(*high.clone())),
414                _ => None,
415            },
416
417            // TableWithJoins navigation
418            (AstNode::TableWithJoins(table_with_joins), PathComponent::Field(field)) => match field
419            {
420                AstField::Relation => Some(AstNode::TableFactor(table_with_joins.relation.clone())),
421                AstField::Joins => {
422                    if !table_with_joins.joins.is_empty() {
423                        Some(AstNode::JoinList(table_with_joins.joins.clone()))
424                    } else {
425                        None
426                    }
427                }
428                _ => None,
429            },
430
431            // Join navigation
432            (AstNode::Join(join), PathComponent::Field(field)) => match field {
433                AstField::Relation => Some(AstNode::TableFactor(join.relation.clone())),
434                AstField::JoinOperator => None, // JoinOperator is simple enum, skip navigation
435                _ => None,
436            },
437
438            // TableFactor navigation
439            (AstNode::TableFactor(table_factor), PathComponent::Field(field)) => {
440                match (table_factor, field) {
441                    (TableFactor::Table { .. }, _) => None, // Table references are terminal
442                    (TableFactor::Derived { subquery, .. }, AstField::Subquery) => {
443                        Some(AstNode::Query(subquery.clone()))
444                    }
445                    (TableFactor::TableFunction { .. }, _) => None, // Function calls are complex
446                    _ => None,
447                }
448            }
449
450            // JoinList navigation (for Vec<Join>)
451            (AstNode::JoinList(joins), PathComponent::Index(idx)) => {
452                joins.get(*idx).map(|join| AstNode::Join(join.clone()))
453            }
454
455            // SelectItem navigation
456            (AstNode::SelectItem(select_item), PathComponent::Field(field)) => {
457                match (select_item, field) {
458                    (SelectItem::UnnamedExpr(expr), AstField::Expr) => {
459                        Some(AstNode::Expr(expr.clone()))
460                    }
461                    (SelectItem::ExprWithAlias { expr, .. }, AstField::Expr) => {
462                        Some(AstNode::Expr(expr.clone()))
463                    }
464                    (SelectItem::QualifiedWildcard(..), _) => None,
465                    (SelectItem::Wildcard(..), _) => None,
466                    _ => None,
467                }
468            }
469
470            // OrderByExpr navigation
471            (AstNode::OrderByExpr(order_by), PathComponent::Field(field)) => match field {
472                AstField::Expr => Some(AstNode::Expr(order_by.expr.clone())),
473                AstField::Asc => None,        // Boolean, not navigable
474                AstField::NullsFirst => None, // Option<bool>, not navigable
475                _ => None,
476            },
477
478            // With clause navigation
479            (AstNode::With(with_clause), PathComponent::Field(field)) => match field {
480                AstField::CteTable => {
481                    if with_clause.cte_tables.is_empty() {
482                        None
483                    } else {
484                        Some(AstNode::CteList(with_clause.cte_tables.clone()))
485                    }
486                }
487                AstField::Recursive => None, // Boolean, not navigable
488                _ => None,
489            },
490
491            // CTE list navigation
492            (AstNode::CteList(ctes), PathComponent::Index(idx)) => {
493                if *idx < ctes.len() {
494                    Some(AstNode::Cte(ctes[*idx].clone()))
495                } else {
496                    None
497                }
498            }
499
500            // CTE navigation
501            (AstNode::Cte(cte), PathComponent::Field(field)) => match field {
502                AstField::Alias => None, // TableAlias is complex, skip for now
503                AstField::CteInner => match &cte.cte_inner {
504                    CteInner::Query(query) => Some(AstNode::Query(query.clone())),
505                    CteInner::ChangeLog(_) => None, // ObjectName is complex, skip for now
506                },
507                _ => None,
508            },
509
510            _ => {
511                // Add debug logging for unmatched cases
512                tracing::debug!(
513                    "get_child: No match for {:?} with component {:?}",
514                    std::mem::discriminant(self),
515                    component
516                );
517                None
518            }
519        }
520    }
521
522    /// Set a child node using a path component.
523    /// Returns a new `AstNode` with the modification applied.
524    pub fn set_child(
525        &self,
526        component: &PathComponent,
527        new_child: Option<AstNode>,
528    ) -> Option<AstNode> {
529        match (self, component) {
530            // Handle Statement::Query (root query statements)
531            (AstNode::Statement(Statement::Query(_)), PathComponent::Field(field))
532                if *field == AstField::Query =>
533            {
534                if let Some(AstNode::Query(new_query)) = new_child {
535                    Some(AstNode::Statement(Statement::Query(new_query)))
536                } else {
537                    None
538                }
539            }
540
541            (
542                AstNode::Statement(Statement::CreateView {
543                    name,
544                    columns,
545                    query: _,
546                    or_replace,
547                    materialized,
548                    if_not_exists,
549                    emit_mode,
550                    with_options,
551                }),
552                PathComponent::Field(field),
553            ) => match field {
554                AstField::Query => {
555                    if let Some(AstNode::Query(new_query)) = new_child {
556                        let new_stmt = Statement::CreateView {
557                            or_replace: *or_replace,
558                            materialized: *materialized,
559                            if_not_exists: *if_not_exists,
560                            name: name.clone(),
561                            columns: columns.clone(),
562                            query: new_query,
563                            emit_mode: emit_mode.clone(),
564                            with_options: with_options.clone(),
565                        };
566                        Some(AstNode::Statement(new_stmt))
567                    } else {
568                        None
569                    }
570                }
571                _ => None,
572            },
573
574            // Query field modifications
575            (AstNode::Query(query), PathComponent::Field(field)) => {
576                let mut new_query = (**query).clone();
577                match field {
578                    AstField::Body => {
579                        if let Some(AstNode::Select(select)) = new_child {
580                            new_query.body = SetExpr::Select(select);
581                            Some(AstNode::Query(Box::new(new_query)))
582                        } else {
583                            None
584                        }
585                    }
586                    AstField::OrderBy => {
587                        if let Some(AstNode::OrderByList(orders)) = new_child {
588                            new_query.order_by = orders;
589                        } else {
590                            new_query.order_by = vec![];
591                        }
592                        Some(AstNode::Query(Box::new(new_query)))
593                    }
594                    AstField::Limit => {
595                        new_query.limit = new_child.and_then(|n| match n {
596                            AstNode::Expr(e) => Some(e),
597                            _ => None,
598                        });
599                        Some(AstNode::Query(Box::new(new_query)))
600                    }
601                    AstField::With => {
602                        new_query.with = new_child.and_then(|n| match n {
603                            AstNode::With(w) => Some(w),
604                            _ => None,
605                        });
606                        Some(AstNode::Query(Box::new(new_query)))
607                    }
608                    _ => None,
609                }
610            }
611
612            // Select field modifications
613            (AstNode::Select(select), PathComponent::Field(field)) => {
614                let mut new_select = (**select).clone();
615                match field {
616                    AstField::Selection => {
617                        new_select.selection = new_child.and_then(|n| match n {
618                            AstNode::Expr(e) => Some(e),
619                            _ => None,
620                        });
621                    }
622                    AstField::Having => {
623                        new_select.having = new_child.and_then(|n| match n {
624                            AstNode::Expr(e) => Some(e),
625                            _ => None,
626                        });
627                    }
628                    AstField::Projection => {
629                        if let Some(AstNode::SelectItemList(items)) = new_child {
630                            new_select.projection = items;
631                        } else {
632                            // Remove projection by setting to empty vec (SELECT without columns is invalid, but we try it)
633                            new_select.projection = vec![];
634                        }
635                    }
636                    AstField::From => {
637                        if let Some(AstNode::TableList(tables)) = new_child {
638                            new_select.from = tables;
639                        } else {
640                            // Remove FROM clause by setting to empty vec
641                            new_select.from = vec![];
642                        }
643                    }
644                    AstField::GroupBy => {
645                        if let Some(AstNode::ExprList(exprs)) = new_child {
646                            new_select.group_by = exprs;
647                        } else {
648                            // Remove GROUP BY by setting to empty vec
649                            new_select.group_by = vec![];
650                        }
651                    }
652                    _ => return None,
653                }
654                Some(AstNode::Select(Box::new(new_select)))
655            }
656
657            // List modifications
658            (AstNode::SelectItemList(items), PathComponent::Index(idx)) => {
659                let mut new_items = items.clone();
660                if *idx < new_items.len() {
661                    if let Some(AstNode::SelectItem(item)) = new_child {
662                        new_items[*idx] = item;
663                    } else {
664                        new_items.remove(*idx);
665                    }
666                    Some(AstNode::SelectItemList(new_items))
667                } else {
668                    None
669                }
670            }
671
672            (AstNode::ExprList(exprs), PathComponent::Index(idx)) => {
673                let mut new_exprs = exprs.clone();
674                if *idx < new_exprs.len() {
675                    if let Some(AstNode::Expr(expr)) = new_child {
676                        new_exprs[*idx] = expr;
677                    } else {
678                        new_exprs.remove(*idx);
679                    }
680                    Some(AstNode::ExprList(new_exprs))
681                } else {
682                    None
683                }
684            }
685
686            (AstNode::TableList(tables), PathComponent::Index(idx)) => {
687                let mut new_tables = tables.clone();
688                if *idx < new_tables.len() {
689                    if let Some(AstNode::TableWithJoins(table)) = new_child {
690                        new_tables[*idx] = table;
691                    } else {
692                        new_tables.remove(*idx);
693                    }
694                    Some(AstNode::TableList(new_tables))
695                } else {
696                    None
697                }
698            }
699
700            // TableWithJoins modifications
701            (AstNode::TableWithJoins(table_with_joins), PathComponent::Field(field)) => {
702                match field {
703                    AstField::Relation => {
704                        if let Some(AstNode::TableFactor(new_relation)) = new_child {
705                            Some(AstNode::TableWithJoins(TableWithJoins {
706                                relation: new_relation,
707                                joins: table_with_joins.joins.clone(),
708                            }))
709                        } else {
710                            None
711                        }
712                    }
713                    AstField::Joins => {
714                        if let Some(AstNode::JoinList(new_joins)) = new_child {
715                            Some(AstNode::TableWithJoins(TableWithJoins {
716                                relation: table_with_joins.relation.clone(),
717                                joins: new_joins,
718                            }))
719                        } else {
720                            // Allow removing joins by setting to empty list
721                            Some(AstNode::TableWithJoins(TableWithJoins {
722                                relation: table_with_joins.relation.clone(),
723                                joins: vec![],
724                            }))
725                        }
726                    }
727                    _ => None,
728                }
729            }
730
731            // Join modifications
732            (AstNode::Join(join), PathComponent::Field(field)) => match field {
733                AstField::Relation => {
734                    if let Some(AstNode::TableFactor(new_relation)) = new_child {
735                        Some(AstNode::Join(Join {
736                            relation: new_relation,
737                            join_operator: join.join_operator.clone(),
738                        }))
739                    } else {
740                        None
741                    }
742                }
743                _ => None,
744            },
745
746            // TableFactor modifications
747            (AstNode::TableFactor(table_factor), PathComponent::Field(field)) => {
748                match (table_factor, field) {
749                    (TableFactor::Derived { lateral, alias, .. }, AstField::Subquery) => {
750                        if let Some(AstNode::Query(new_subquery)) = new_child {
751                            Some(AstNode::TableFactor(TableFactor::Derived {
752                                lateral: *lateral,
753                                subquery: new_subquery,
754                                alias: alias.clone(),
755                            }))
756                        } else {
757                            None
758                        }
759                    }
760                    _ => None,
761                }
762            }
763
764            // JoinList modifications
765            (AstNode::JoinList(joins), PathComponent::Index(idx)) => {
766                let mut new_joins = joins.clone();
767                if *idx < new_joins.len() {
768                    if let Some(AstNode::Join(new_join)) = new_child {
769                        new_joins[*idx] = new_join;
770                    } else {
771                        new_joins.remove(*idx);
772                    }
773                    Some(AstNode::JoinList(new_joins))
774                } else {
775                    None
776                }
777            }
778
779            (AstNode::OrderByList(orders), PathComponent::Index(idx)) => {
780                let mut new_orders = orders.clone();
781                if *idx < new_orders.len() {
782                    if let Some(AstNode::OrderByExpr(order)) = new_child {
783                        new_orders[*idx] = order;
784                    } else {
785                        new_orders.remove(*idx);
786                    }
787                    Some(AstNode::OrderByList(new_orders))
788                } else {
789                    None
790                }
791            }
792
793            // Expression field modifications
794            (AstNode::Expr(expr), PathComponent::Field(field)) => match (expr, field) {
795                (Expr::BinaryOp { left: _, op, right }, AstField::Left) => {
796                    if let Some(AstNode::Expr(new_left)) = new_child {
797                        Some(AstNode::Expr(Expr::BinaryOp {
798                            left: Box::new(new_left),
799                            op: op.clone(),
800                            right: right.clone(),
801                        }))
802                    } else {
803                        None
804                    }
805                }
806                (Expr::BinaryOp { left, op, right: _ }, AstField::Right) => {
807                    if let Some(AstNode::Expr(new_right)) = new_child {
808                        Some(AstNode::Expr(Expr::BinaryOp {
809                            left: left.clone(),
810                            op: op.clone(),
811                            right: Box::new(new_right),
812                        }))
813                    } else {
814                        None
815                    }
816                }
817                (Expr::Nested(_), AstField::Inner) => {
818                    if let Some(AstNode::Expr(new_inner)) = new_child {
819                        Some(AstNode::Expr(Expr::Nested(Box::new(new_inner))))
820                    } else {
821                        None
822                    }
823                }
824                _ => None,
825            },
826
827            // SelectItem field modifications
828            (AstNode::SelectItem(select_item), PathComponent::Field(field)) => {
829                match (select_item, field) {
830                    (SelectItem::UnnamedExpr(_), AstField::Expr) => {
831                        if let Some(AstNode::Expr(new_expr)) = new_child {
832                            Some(AstNode::SelectItem(SelectItem::UnnamedExpr(new_expr)))
833                        } else {
834                            None
835                        }
836                    }
837                    (SelectItem::ExprWithAlias { alias, .. }, AstField::Expr) => {
838                        if let Some(AstNode::Expr(new_expr)) = new_child {
839                            Some(AstNode::SelectItem(SelectItem::ExprWithAlias {
840                                expr: new_expr,
841                                alias: alias.clone(),
842                            }))
843                        } else {
844                            None
845                        }
846                    }
847                    _ => None,
848                }
849            }
850
851            // OrderByExpr field modifications
852            (AstNode::OrderByExpr(order_by), PathComponent::Field(field)) => match field {
853                AstField::Expr => {
854                    if let Some(AstNode::Expr(new_expr)) = new_child {
855                        Some(AstNode::OrderByExpr(OrderByExpr {
856                            expr: new_expr,
857                            asc: order_by.asc,
858                            nulls_first: order_by.nulls_first,
859                        }))
860                    } else {
861                        None
862                    }
863                }
864                _ => None,
865            },
866
867            // With clause modifications
868            (AstNode::With(with_clause), PathComponent::Field(field)) => match field {
869                AstField::CteTable => {
870                    if let Some(AstNode::CteList(new_ctes)) = new_child {
871                        let mut new_with = with_clause.clone();
872                        new_with.cte_tables = new_ctes;
873                        Some(AstNode::With(new_with))
874                    } else {
875                        // Remove all CTEs by setting to empty vec
876                        let mut new_with = with_clause.clone();
877                        new_with.cte_tables = vec![];
878                        Some(AstNode::With(new_with))
879                    }
880                }
881                _ => None,
882            },
883
884            // CTE list modifications
885            (AstNode::CteList(ctes), PathComponent::Index(idx)) => {
886                if let Some(AstNode::Cte(new_cte)) = new_child {
887                    if *idx < ctes.len() {
888                        let mut new_ctes = ctes.clone();
889                        new_ctes[*idx] = new_cte;
890                        Some(AstNode::CteList(new_ctes))
891                    } else {
892                        None
893                    }
894                } else {
895                    None
896                }
897            }
898
899            // CTE modifications
900            (AstNode::Cte(cte), PathComponent::Field(field)) => match field {
901                AstField::CteInner => {
902                    if let Some(AstNode::Query(new_query)) = new_child {
903                        let mut new_cte = cte.clone();
904                        new_cte.cte_inner = CteInner::Query(new_query);
905                        Some(AstNode::Cte(new_cte))
906                    } else {
907                        None
908                    }
909                }
910                _ => None,
911            },
912
913            _ => {
914                // Add debug logging for unmatched cases
915                tracing::debug!(
916                    "set_child: No match for {} ({:?}) with component {:?}",
917                    AstNodeType::from(self),
918                    std::mem::discriminant(self),
919                    component
920                );
921                None
922            }
923        }
924    }
925
926    /// Convert back to a Statement if this is the root node.
927    pub fn to_statement(&self) -> Option<Statement> {
928        match self {
929            AstNode::Statement(stmt) => Some(stmt.clone()),
930            _ => None,
931        }
932    }
933}
934
935/// Navigate to a node in the AST using the given path.
936/// Enables precise AST node retrieval.
937pub fn get_node_at_path(root: &AstNode, path: &AstPath) -> Option<AstNode> {
938    let mut current = root.clone();
939    for component in path {
940        current = current.get_child(component)?;
941    }
942    Some(current)
943}
944
945/// Set a node in the AST at the given path.
946/// Enables precise AST node modification.
947pub fn set_node_at_path(
948    root: &AstNode,
949    path: &AstPath,
950    new_node: Option<AstNode>,
951) -> Option<AstNode> {
952    if path.is_empty() {
953        return new_node;
954    }
955
956    let result = root.clone();
957    let mut current_path = Vec::new();
958
959    // Navigate to the parent of the target node
960    for component in &path[..path.len() - 1] {
961        current_path.push(component.clone());
962    }
963
964    // Get the parent node
965    let parent = get_node_at_path(&result, &current_path)?;
966
967    // Apply the modification to the parent
968    let modified_parent = parent.set_child(&path[path.len() - 1], new_node)?;
969
970    // Now we need to set this modified parent back in the tree
971    if current_path.is_empty() {
972        Some(modified_parent)
973    } else {
974        set_node_at_path(&result, &current_path, Some(modified_parent))
975    }
976}
977
978/// Helper function to get a child node and recurse if it exists.
979fn explore_child_field(
980    node: &AstNode,
981    field: AstField,
982    current_path: &AstPath,
983    paths: &mut Vec<AstPath>,
984) {
985    let field_component = PathComponent::field(field);
986    let child_path = [current_path.clone(), vec![field_component.clone()]].concat();
987    let relative_path = vec![field_component];
988
989    if let Some(child_node) = get_node_at_path(node, &relative_path) {
990        // Collect child paths but don't add them yet (for outer-first ordering)
991        let child_paths = enumerate_reduction_paths(&child_node, child_path);
992        paths.extend(child_paths);
993    }
994}
995
996/// Calculate the depth of a path (number of components).
997/// Used for outer-first ordering: shallower paths (outer queries) come first.
998fn path_depth(path: &AstPath) -> usize {
999    path.len()
1000}
1001
1002/// Enumerate all interesting paths in the AST for reduction.
1003/// Systematically discovers all reducible AST locations.
1004pub fn enumerate_reduction_paths(node: &AstNode, current_path: AstPath) -> Vec<AstPath> {
1005    let mut paths = vec![current_path.clone()];
1006
1007    tracing::debug!(
1008        "Enumerating paths for node {:?} at path {}",
1009        get_node_type_name(node),
1010        display_ast_path(&current_path)
1011    );
1012
1013    match node {
1014        AstNode::Statement(Statement::Query(_)) => {
1015            explore_child_field(node, AstField::Query, &current_path, &mut paths);
1016        }
1017
1018        AstNode::Statement(Statement::CreateView { .. }) => {
1019            explore_child_field(node, AstField::Query, &current_path, &mut paths);
1020        }
1021
1022        AstNode::Query(_query) => {
1023            explore_child_field(node, AstField::Body, &current_path, &mut paths);
1024            explore_child_field(node, AstField::With, &current_path, &mut paths);
1025            explore_child_field(node, AstField::OrderBy, &current_path, &mut paths);
1026        }
1027
1028        AstNode::Select(_) => {
1029            explore_child_field(node, AstField::Projection, &current_path, &mut paths);
1030            explore_child_field(node, AstField::From, &current_path, &mut paths);
1031            explore_child_field(node, AstField::GroupBy, &current_path, &mut paths);
1032
1033            // For optional fields, just add the path if they exist
1034            let selection_path = [current_path.clone(), vec![PathComponent::selection()]].concat();
1035            if get_node_at_path(node, &vec![PathComponent::selection()]).is_some() {
1036                paths.push(selection_path);
1037            }
1038
1039            let having_path = [current_path.clone(), vec![PathComponent::having()]].concat();
1040            if get_node_at_path(node, &vec![PathComponent::having()]).is_some() {
1041                paths.push(having_path);
1042            }
1043        }
1044
1045        // For lists, enumerate individual elements
1046        AstNode::SelectItemList(items) => {
1047            for i in 0..items.len() {
1048                let item_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1049                paths.push(item_path);
1050            }
1051        }
1052
1053        AstNode::ExprList(exprs) => {
1054            for i in 0..exprs.len() {
1055                let expr_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1056                paths.push(expr_path.clone());
1057                // Also descend into expressions for pullup opportunities
1058                let relative_path = vec![PathComponent::Index(i)];
1059                if let Some(expr_node) = get_node_at_path(node, &relative_path) {
1060                    paths.extend(enumerate_reduction_paths(&expr_node, expr_path));
1061                }
1062            }
1063        }
1064
1065        AstNode::TableList(tables) => {
1066            for i in 0..tables.len() {
1067                let table_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1068                paths.push(table_path.clone());
1069                // Recursively explore TableWithJoins
1070                if let Some(table_node) = node.get_child(&PathComponent::Index(i)) {
1071                    paths.extend(enumerate_reduction_paths(&table_node, table_path));
1072                }
1073            }
1074        }
1075
1076        // TableWithJoins path enumeration - key for JOIN reduction
1077        AstNode::TableWithJoins(_) => {
1078            explore_child_field(node, AstField::Relation, &current_path, &mut paths);
1079            explore_child_field(node, AstField::Joins, &current_path, &mut paths);
1080        }
1081
1082        // Join path enumeration
1083        AstNode::Join(_) => {
1084            explore_child_field(node, AstField::Relation, &current_path, &mut paths);
1085        }
1086
1087        // TableFactor path enumeration
1088        AstNode::TableFactor(_) => {
1089            // For Derived tables (subqueries), explore the subquery
1090            explore_child_field(node, AstField::Subquery, &current_path, &mut paths);
1091        }
1092
1093        // JoinList path enumeration
1094        AstNode::JoinList(joins) => {
1095            for i in 0..joins.len() {
1096                let join_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1097                paths.push(join_path.clone());
1098                // Recursively explore Join
1099                if let Some(join_node) = node.get_child(&PathComponent::Index(i)) {
1100                    paths.extend(enumerate_reduction_paths(&join_node, join_path));
1101                }
1102            }
1103        }
1104
1105        AstNode::OrderByList(orders) => {
1106            for i in 0..orders.len() {
1107                let order_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1108                paths.push(order_path);
1109            }
1110        }
1111
1112        // For expressions, look for pullup opportunities
1113        AstNode::Expr(expr) => match expr {
1114            Expr::BinaryOp { .. } => {
1115                let left_path = [current_path.clone(), vec![PathComponent::left()]].concat();
1116                let right_path = [current_path.clone(), vec![PathComponent::right()]].concat();
1117                paths.push(left_path);
1118                paths.push(right_path);
1119            }
1120            Expr::Case {
1121                operand,
1122                else_result,
1123                ..
1124            } => {
1125                if operand.is_some() {
1126                    let operand_path =
1127                        [current_path.clone(), vec![PathComponent::operand()]].concat();
1128                    paths.push(operand_path);
1129                }
1130                if else_result.is_some() {
1131                    let else_path =
1132                        [current_path.clone(), vec![PathComponent::else_result()]].concat();
1133                    paths.push(else_path);
1134                }
1135            }
1136            Expr::Exists(_) => {
1137                explore_child_field(node, AstField::Subquery, &current_path, &mut paths);
1138            }
1139            Expr::Subquery(_) => {
1140                explore_child_field(node, AstField::Subquery, &current_path, &mut paths);
1141            }
1142            _ => {}
1143        },
1144
1145        // WITH clause enumeration
1146        AstNode::With(_) => {
1147            explore_child_field(node, AstField::CteTable, &current_path, &mut paths);
1148        }
1149
1150        // CTE list enumeration
1151        AstNode::CteList(ctes) => {
1152            for i in 0..ctes.len() {
1153                let cte_path = [current_path.clone(), vec![PathComponent::Index(i)]].concat();
1154                paths.push(cte_path.clone());
1155
1156                // Recursively enumerate paths within each CTE
1157                if let Some(cte_node) = get_node_at_path(node, &vec![PathComponent::Index(i)]) {
1158                    paths.extend(enumerate_reduction_paths(&cte_node, cte_path));
1159                }
1160            }
1161        }
1162
1163        // CTE enumeration
1164        AstNode::Cte(_) => {
1165            explore_child_field(node, AstField::CteInner, &current_path, &mut paths);
1166        }
1167
1168        _ => {}
1169    }
1170
1171    // Sort paths by depth (outer-first): shallower paths come first
1172    // This ensures outer queries are reduced before inner subqueries
1173    // For example: SELECT (SELECT ...) will reduce the outer SELECT first
1174    paths.sort_by_key(path_depth);
1175
1176    paths
1177}
1178
1179/// Convert a Statement to an `AstNode` for path-based operations.
1180pub fn statement_to_ast_node(stmt: &Statement) -> AstNode {
1181    AstNode::Statement(stmt.clone())
1182}
1183
1184/// Extract a Statement from an `AstNode`.
1185pub fn ast_node_to_statement(node: &AstNode) -> Option<Statement> {
1186    match node {
1187        AstNode::Statement(stmt) => Some(stmt.clone()),
1188        AstNode::Query(query) => Some(Statement::Query(Box::new(query.as_ref().clone()))),
1189        _ => None,
1190    }
1191}
1192
1193/// Get a human-readable name for an AST node type.
1194pub fn get_node_type_name(node: &AstNode) -> String {
1195    AstNodeType::from(node).to_string()
1196}
1197
#[cfg(test)]
mod tests {
    use risingwave_sqlparser::parser::Parser;

    use super::*;

    /// Parse `sql` and return its first statement.
    fn parse_first(sql: &str) -> Statement {
        let mut parsed = Parser::parse_sql(sql).expect("Failed to parse SQL");
        assert!(!parsed.is_empty(), "expected at least one statement in {sql:?}");
        parsed.swap_remove(0)
    }

    /// Check the ordering guarantees of `enumerate_reduction_paths` when it is
    /// started at the root: the root path comes first and depth never decreases.
    fn assert_outer_first(paths: &[AstPath]) {
        assert!(
            matches!(paths.first(), Some(p) if p.is_empty()),
            "the root path must come first"
        );
        assert!(
            paths.windows(2).all(|w| w[0].len() <= w[1].len()),
            "paths must be sorted by depth (outer-first)"
        );
    }

    #[test]
    fn test_path_enumeration() {
        let stmt = parse_first("SELECT a FROM b;");
        let ast_node = statement_to_ast_node(&stmt);

        let paths = enumerate_reduction_paths(&ast_node, vec![]);

        // Should have at least a few paths for this simple query
        assert!(paths.len() >= 3);
        assert_outer_first(&paths);
        println!("Found {} paths for simple query", paths.len());
    }

    #[test]
    fn test_create_materialized_view() {
        let sql = "CREATE MATERIALIZED VIEW stream_query AS SELECT min((tumble_0.c5 + tumble_0.c5)) AS col_0, false AS col_1, TIME '07:30:48' AS col_2, tumble_0.c14 AS col_3 FROM tumble(alltypes1, alltypes1.c11, INTERVAL '10') AS tumble_0 WHERE tumble_0.c1 GROUP BY tumble_0.c10, tumble_0.c9, tumble_0.c13, tumble_0.c1, tumble_0.c16, tumble_0.c14, tumble_0.c5, tumble_0.c8;";
        let stmt = parse_first(sql);
        let ast_node = statement_to_ast_node(&stmt);

        let paths = enumerate_reduction_paths(&ast_node, vec![]);

        println!("Found {} paths for CREATE MATERIALIZED VIEW", paths.len());
        for (i, path) in paths.iter().enumerate() {
            println!("Path {}: {}", i, display_ast_path(path));
        }

        // Should have paths for the SELECT statement inside the MV
        assert!(!paths.is_empty());
        assert_outer_first(&paths);

        // Should be able to get some node
        assert!(get_node_at_path(&ast_node, &paths[0]).is_some());
    }

    #[test]
    fn test_statement_round_trip() {
        let stmt = parse_first("SELECT a FROM b WHERE c > 1;");
        let ast_node = statement_to_ast_node(&stmt);

        assert_eq!(ast_node_to_statement(&ast_node), Some(stmt.clone()));
        assert_eq!(ast_node.to_statement(), Some(stmt.clone()));

        // An empty path addresses the root itself.
        let empty: AstPath = Vec::new();
        let root = get_node_at_path(&ast_node, &empty).expect("empty path resolves to root");
        assert_eq!(root.to_statement(), Some(stmt));
    }

    #[test]
    fn test_set_node_at_empty_path_replaces_root() {
        let stmt = parse_first("SELECT 1;");
        let ast_node = statement_to_ast_node(&stmt);
        let empty: AstPath = Vec::new();

        // Replacing the root with nothing yields nothing.
        assert!(set_node_at_path(&ast_node, &empty, None).is_none());

        // Replacing the root with a node yields exactly that node.
        let replaced = set_node_at_path(&ast_node, &empty, Some(ast_node.clone()));
        assert_eq!(replaced.and_then(|n| n.to_statement()), Some(stmt));
    }
}