risingwave_frontend/expr/
subquery.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::hash::Hash;
16
17use risingwave_common::types::{DataType, StructType};
18
19use super::{Expr, ExprImpl, ExprType};
20use crate::binder::{BoundQuery, UNNAMED_COLUMN};
21use crate::expr::{CorrelatedId, Depth};
22
23#[derive(Clone, Debug, PartialEq, Eq)]
24pub enum SubqueryKind {
25    /// Returns a scalar value (single column single row).
26    Scalar,
27    /// Returns a scalar struct value composed of multiple columns.
28    /// Used in `UPDATE SET (col1, col2) = (SELECT ...)`.
29    UpdateSet,
30    /// `EXISTS` | `NOT EXISTS` subquery (semi/anti-semi join). Returns a boolean.
31    Existential,
32    /// `IN` subquery.
33    In(ExprImpl),
34    /// Expression operator `SOME` subquery.
35    Some(ExprImpl, ExprType),
36    /// Expression operator `ALL` subquery.
37    All(ExprImpl, ExprType),
38    /// Expression operator `ARRAY` subquery.
39    Array,
40}
41
42/// Subquery expression.
43#[derive(Clone)]
44pub struct Subquery {
45    pub query: BoundQuery,
46    pub kind: SubqueryKind,
47}
48
49impl Subquery {
50    pub fn new(query: BoundQuery, kind: SubqueryKind) -> Self {
51        Self { query, kind }
52    }
53
54    pub fn is_correlated_by_depth(&self, depth: Depth) -> bool {
55        let is_correlated = match &self.kind {
56            SubqueryKind::In(expr) => expr.has_correlated_input_ref_by_depth(depth),
57            SubqueryKind::Some(expr, _) | SubqueryKind::All(expr, _) => {
58                expr.has_correlated_input_ref_by_depth(depth)
59            }
60            SubqueryKind::Array
61            | SubqueryKind::Scalar
62            | SubqueryKind::UpdateSet
63            | SubqueryKind::Existential => false,
64        };
65        is_correlated || self.query.is_correlated_by_depth(depth)
66    }
67
68    pub fn is_correlated_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
69        let is_correlated = match &self.kind {
70            SubqueryKind::In(expr) => expr.has_correlated_input_ref_by_correlated_id(correlated_id),
71            SubqueryKind::Some(expr, _) | SubqueryKind::All(expr, _) => {
72                expr.has_correlated_input_ref_by_correlated_id(correlated_id)
73            }
74            SubqueryKind::Array
75            | SubqueryKind::Scalar
76            | SubqueryKind::UpdateSet
77            | SubqueryKind::Existential => false,
78        };
79        is_correlated || self.query.is_correlated_by_correlated_id(correlated_id)
80    }
81
82    pub fn collect_correlated_indices_by_depth_and_assign_id(
83        &mut self,
84        depth: Depth,
85        correlated_id: CorrelatedId,
86    ) -> Vec<usize> {
87        let mut correlated_indices = self
88            .query
89            .collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id);
90
91        match &mut self.kind {
92            SubqueryKind::In(expr) => {
93                correlated_indices.extend(
94                    expr.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
95                );
96            }
97            SubqueryKind::Some(expr, _) | SubqueryKind::All(expr, _) => {
98                correlated_indices.extend(
99                    expr.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
100                );
101            }
102            SubqueryKind::Array
103            | SubqueryKind::Scalar
104            | SubqueryKind::UpdateSet
105            | SubqueryKind::Existential => {
106                // No additional correlated indices to collect for these kinds.
107            }
108        }
109        correlated_indices.sort();
110        correlated_indices.dedup();
111        correlated_indices
112    }
113}
114
115impl PartialEq for Subquery {
116    fn eq(&self, _other: &Self) -> bool {
117        unreachable!("Subquery {:?} has not been unnested", self)
118    }
119}
120
121impl Hash for Subquery {
122    fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {
123        unreachable!("Subquery {:?} has not been hashed", self)
124    }
125}
126
127impl Eq for Subquery {}
128
129impl Expr for Subquery {
130    fn return_type(&self) -> DataType {
131        match self.kind {
132            SubqueryKind::Scalar => {
133                let types = self.query.data_types();
134                assert_eq!(types.len(), 1, "Subquery with more than one column");
135                types[0].clone()
136            }
137            SubqueryKind::UpdateSet => {
138                let schema = self.query.schema();
139                let struct_type = if schema.fields().iter().any(|f| f.name == UNNAMED_COLUMN) {
140                    StructType::unnamed(self.query.data_types())
141                } else {
142                    StructType::new(
143                        (schema.fields().iter().cloned()).map(|f| (f.name, f.data_type)),
144                    )
145                };
146                DataType::Struct(struct_type)
147            }
148            SubqueryKind::Array => {
149                let types = self.query.data_types();
150                assert_eq!(types.len(), 1, "Subquery with more than one column");
151                DataType::List(types[0].clone().into())
152            }
153            _ => DataType::Boolean,
154        }
155    }
156
157    fn to_expr_proto(&self) -> risingwave_pb::expr::ExprNode {
158        unreachable!("Subquery {:?} has not been unnested", self)
159    }
160}
161
162impl std::fmt::Debug for Subquery {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        f.debug_struct("Subquery")
165            .field("kind", &self.kind)
166            .field("query", &self.query)
167            .finish()
168    }
169}