risingwave_frontend/expr/
subquery.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::hash::Hash;
16
17use risingwave_common::types::{DataType, StructType};
18
19use super::{Expr, ExprImpl, ExprType};
20use crate::binder::{BoundQuery, UNNAMED_COLUMN};
21use crate::expr::{CorrelatedId, Depth};
22
23#[derive(Clone, Debug, PartialEq, Eq)]
24pub enum SubqueryKind {
25    /// Returns a scalar value (single column single row).
26    Scalar,
27    /// Returns a scalar struct value composed of multiple columns.
28    /// Used in `UPDATE SET (col1, col2) = (SELECT ...)`.
29    UpdateSet,
30    /// `EXISTS` | `NOT EXISTS` subquery (semi/anti-semi join). Returns a boolean.
31    Existential,
32    /// `IN` subquery.
33    In(ExprImpl),
34    /// Expression operator `SOME` subquery.
35    Some(ExprImpl, ExprType),
36    /// Expression operator `ALL` subquery.
37    All(ExprImpl, ExprType),
38    /// Expression operator `ARRAY` subquery.
39    Array,
40}
41
42/// Subquery expression.
43#[derive(Clone)]
44pub struct Subquery {
45    pub query: BoundQuery,
46    pub kind: SubqueryKind,
47}
48
49impl Subquery {
50    pub fn new(query: BoundQuery, kind: SubqueryKind) -> Self {
51        Self { query, kind }
52    }
53
54    pub fn is_correlated(&self, depth: Depth) -> bool {
55        self.query.is_correlated(depth)
56    }
57
58    pub fn collect_correlated_indices_by_depth_and_assign_id(
59        &mut self,
60        depth: Depth,
61        correlated_id: CorrelatedId,
62    ) -> Vec<usize> {
63        let mut correlated_indices = self
64            .query
65            .collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id);
66        correlated_indices.sort();
67        correlated_indices.dedup();
68        correlated_indices
69    }
70}
71
72impl PartialEq for Subquery {
73    fn eq(&self, _other: &Self) -> bool {
74        unreachable!("Subquery {:?} has not been unnested", self)
75    }
76}
77
78impl Hash for Subquery {
79    fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {
80        unreachable!("Subquery {:?} has not been hashed", self)
81    }
82}
83
84impl Eq for Subquery {}
85
86impl Expr for Subquery {
87    fn return_type(&self) -> DataType {
88        match self.kind {
89            SubqueryKind::Scalar => {
90                let types = self.query.data_types();
91                assert_eq!(types.len(), 1, "Subquery with more than one column");
92                types[0].clone()
93            }
94            SubqueryKind::UpdateSet => {
95                let schema = self.query.schema();
96                let struct_type = if schema.fields().iter().any(|f| f.name == UNNAMED_COLUMN) {
97                    StructType::unnamed(self.query.data_types())
98                } else {
99                    StructType::new(
100                        (schema.fields().iter().cloned()).map(|f| (f.name, f.data_type)),
101                    )
102                };
103                DataType::Struct(struct_type)
104            }
105            SubqueryKind::Array => {
106                let types = self.query.data_types();
107                assert_eq!(types.len(), 1, "Subquery with more than one column");
108                DataType::List(types[0].clone().into())
109            }
110            _ => DataType::Boolean,
111        }
112    }
113
114    fn to_expr_proto(&self) -> risingwave_pb::expr::ExprNode {
115        unreachable!("Subquery {:?} has not been unnested", self)
116    }
117}
118
119impl std::fmt::Debug for Subquery {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        f.debug_struct("Subquery")
122            .field("kind", &self.kind)
123            .field("query", &self.query)
124            .finish()
125    }
126}