risingwave_frontend/expr/
subquery.rs1use std::hash::Hash;
16
17use risingwave_common::types::{DataType, StructType};
18
19use super::{Expr, ExprImpl, ExprType};
20use crate::binder::{BoundQuery, UNNAMED_COLUMN};
21use crate::expr::{CorrelatedId, Depth};
22
23#[derive(Clone, Debug, PartialEq, Eq)]
24pub enum SubqueryKind {
25 Scalar,
27 UpdateSet,
30 Existential,
32 In(ExprImpl),
34 Some(ExprImpl, ExprType),
36 All(ExprImpl, ExprType),
38 Array,
40}
41
42#[derive(Clone)]
44pub struct Subquery {
45 pub query: BoundQuery,
46 pub kind: SubqueryKind,
47}
48
49impl Subquery {
50 pub fn new(query: BoundQuery, kind: SubqueryKind) -> Self {
51 Self { query, kind }
52 }
53
54 pub fn is_correlated_by_depth(&self, depth: Depth) -> bool {
55 let is_correlated = match &self.kind {
56 SubqueryKind::In(expr) => expr.has_correlated_input_ref_by_depth(depth),
57 SubqueryKind::Some(expr, _) | SubqueryKind::All(expr, _) => {
58 expr.has_correlated_input_ref_by_depth(depth)
59 }
60 SubqueryKind::Array
61 | SubqueryKind::Scalar
62 | SubqueryKind::UpdateSet
63 | SubqueryKind::Existential => false,
64 };
65 is_correlated || self.query.is_correlated_by_depth(depth)
66 }
67
68 pub fn is_correlated_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
69 let is_correlated = match &self.kind {
70 SubqueryKind::In(expr) => expr.has_correlated_input_ref_by_correlated_id(correlated_id),
71 SubqueryKind::Some(expr, _) | SubqueryKind::All(expr, _) => {
72 expr.has_correlated_input_ref_by_correlated_id(correlated_id)
73 }
74 SubqueryKind::Array
75 | SubqueryKind::Scalar
76 | SubqueryKind::UpdateSet
77 | SubqueryKind::Existential => false,
78 };
79 is_correlated || self.query.is_correlated_by_correlated_id(correlated_id)
80 }
81
82 pub fn collect_correlated_indices_by_depth_and_assign_id(
83 &mut self,
84 depth: Depth,
85 correlated_id: CorrelatedId,
86 ) -> Vec<usize> {
87 let mut correlated_indices = self
88 .query
89 .collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id);
90
91 match &mut self.kind {
92 SubqueryKind::In(expr) => {
93 correlated_indices.extend(
94 expr.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
95 );
96 }
97 SubqueryKind::Some(expr, _) | SubqueryKind::All(expr, _) => {
98 correlated_indices.extend(
99 expr.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
100 );
101 }
102 SubqueryKind::Array
103 | SubqueryKind::Scalar
104 | SubqueryKind::UpdateSet
105 | SubqueryKind::Existential => {
106 }
108 }
109 correlated_indices.sort();
110 correlated_indices.dedup();
111 correlated_indices
112 }
113}
114
115impl PartialEq for Subquery {
116 fn eq(&self, _other: &Self) -> bool {
117 unreachable!("Subquery {:?} has not been unnested", self)
118 }
119}
120
121impl Hash for Subquery {
122 fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {
123 unreachable!("Subquery {:?} has not been hashed", self)
124 }
125}
126
127impl Eq for Subquery {}
128
129impl Expr for Subquery {
130 fn return_type(&self) -> DataType {
131 match self.kind {
132 SubqueryKind::Scalar => {
133 let types = self.query.data_types();
134 assert_eq!(types.len(), 1, "Subquery with more than one column");
135 types[0].clone()
136 }
137 SubqueryKind::UpdateSet => {
138 let schema = self.query.schema();
139 let struct_type = if schema.fields().iter().any(|f| f.name == UNNAMED_COLUMN) {
140 StructType::unnamed(self.query.data_types())
141 } else {
142 StructType::new(
143 (schema.fields().iter().cloned()).map(|f| (f.name, f.data_type)),
144 )
145 };
146 DataType::Struct(struct_type)
147 }
148 SubqueryKind::Array => {
149 let types = self.query.data_types();
150 assert_eq!(types.len(), 1, "Subquery with more than one column");
151 DataType::List(types[0].clone().into())
152 }
153 _ => DataType::Boolean,
154 }
155 }
156
157 fn to_expr_proto(&self) -> risingwave_pb::expr::ExprNode {
158 unreachable!("Subquery {:?} has not been unnested", self)
159 }
160}
161
162impl std::fmt::Debug for Subquery {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("Subquery")
165 .field("kind", &self.kind)
166 .field("query", &self.query)
167 .finish()
168 }
169}