risingwave_frontend/utils/
column_index_mapping.rs1use std::vec;
16
17use fixedbitset::FixedBitSet;
18pub use risingwave_common::util::column_index_mapping::ColIndexMapping;
19use risingwave_common::util::sort_util::ColumnOrder;
20
21use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef};
22use crate::optimizer::property::{
23 Distribution, FunctionalDependency, FunctionalDependencySet, MonotonicityMap, Order,
24 RequiredDist,
25};
26
27#[easy_ext::ext(ColIndexMappingRewriteExt)]
29impl ColIndexMapping {
30 pub fn rewrite_provided_order(&self, order: &Order) -> Order {
34 let mut mapped_column_orders = vec![];
35 for column_order in &order.column_orders {
36 match self.try_map(column_order.column_index) {
37 Some(mapped_index) => mapped_column_orders
38 .push(ColumnOrder::new(mapped_index, column_order.order_type)),
39 None => break,
40 }
41 }
42 Order {
43 column_orders: mapped_column_orders,
44 }
45 }
46
47 pub fn rewrite_required_order(&self, order: &Order) -> Option<Order> {
52 order
53 .column_orders
54 .iter()
55 .map(|o| {
56 self.try_map(o.column_index)
57 .map(|mapped_index| ColumnOrder::new(mapped_index, o.order_type))
58 })
59 .collect::<Option<Vec<_>>>()
60 .map(|mapped_column_orders| Order {
61 column_orders: mapped_column_orders,
62 })
63 }
64
65 pub fn rewrite_dist_key(&self, key: &[usize]) -> Option<Vec<usize>> {
68 self.try_map_all(key.iter().copied())
69 }
70
71 pub fn rewrite_provided_distribution(&self, dist: &Distribution) -> Distribution {
78 match dist {
79 Distribution::Single | Distribution::SomeShard | Distribution::Broadcast => {
80 dist.clone()
81 }
82
83 Distribution::HashShard(_) => match self.rewrite_dist_key(dist.dist_column_indices()) {
84 Some(mapped_dist_key) => Distribution::HashShard(mapped_dist_key),
85 None => Distribution::SomeShard,
86 },
87 Distribution::UpstreamHashShard(_, table_id) => {
88 match self.rewrite_dist_key(dist.dist_column_indices()) {
89 Some(mapped_dist_key) => {
90 Distribution::UpstreamHashShard(mapped_dist_key, *table_id)
91 }
92 None => Distribution::SomeShard,
93 }
94 }
95 }
96 }
97
98 pub fn rewrite_required_distribution(&self, dist: &RequiredDist) -> RequiredDist {
104 match dist {
105 RequiredDist::ShardByKey(keys) => {
106 assert!(!keys.is_clear());
107 let keys = self.rewrite_bitset(keys);
108 if keys.count_ones(..) == 0 {
109 RequiredDist::Any
110 } else {
111 RequiredDist::ShardByKey(keys)
112 }
113 }
114 RequiredDist::PhysicalDist(dist) => match dist {
115 Distribution::HashShard(keys) => {
116 assert!(!keys.is_empty());
117 let keys = self.rewrite_dist_key(keys);
118 match keys {
119 Some(keys) => RequiredDist::PhysicalDist(Distribution::HashShard(keys)),
120 None => RequiredDist::Any,
121 }
122 }
123 Distribution::UpstreamHashShard(keys, table_id) => {
124 assert!(!keys.is_empty());
125 let keys = self.rewrite_dist_key(keys);
126 match keys {
127 Some(keys) => RequiredDist::PhysicalDist(Distribution::UpstreamHashShard(
128 keys, *table_id,
129 )),
130 None => RequiredDist::Any,
131 }
132 }
133 Distribution::Single | Distribution::Broadcast | Distribution::SomeShard => {
134 RequiredDist::PhysicalDist(dist.clone())
135 }
136 },
137 RequiredDist::Any => RequiredDist::Any,
138 RequiredDist::AnyShard => RequiredDist::AnyShard,
139 }
140 }
141
142 pub fn rewrite_functional_dependency(
151 &self,
152 fd: &FunctionalDependency,
153 ) -> Option<FunctionalDependency> {
154 let new_from = self.rewrite_bitset(fd.from());
155 let new_to = self.rewrite_bitset(fd.to());
156 if new_from.count_ones(..) != fd.from().count_ones(..) || new_to.is_clear() {
157 None
158 } else {
159 Some(FunctionalDependency::new(new_from, new_to))
160 }
161 }
162
163 pub fn rewrite_functional_dependency_set(
171 &self,
172 fd_set: FunctionalDependencySet,
173 ) -> FunctionalDependencySet {
174 let mut new_fd_set = FunctionalDependencySet::new(self.target_size());
175 for i in fd_set.into_dependencies() {
176 if let Some(fd) = self.rewrite_functional_dependency(&i) {
177 new_fd_set.add_functional_dependency(fd);
178 }
179 }
180 new_fd_set
181 }
182
183 pub fn rewrite_bitset(&self, bitset: &FixedBitSet) -> FixedBitSet {
184 assert_eq!(bitset.len(), self.source_size());
185 let mut ret = FixedBitSet::with_capacity(self.target_size());
186 for i in bitset.ones() {
187 if let Some(i) = self.try_map(i) {
188 ret.insert(i);
189 }
190 }
191 ret
192 }
193
194 pub fn rewrite_monotonicity_map(&self, map: &MonotonicityMap) -> MonotonicityMap {
195 let mut new_map = MonotonicityMap::new();
196 for (i, monotonicity) in map.iter() {
197 if let Some(mapped_i) = self.try_map(i) {
198 new_map.insert(mapped_i, monotonicity);
199 }
200 }
201 new_map
202 }
203}
204
205impl ExprRewriter for ColIndexMapping {
206 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
207 InputRef::new(self.map(input_ref.index()), input_ref.return_type()).into()
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapping shared by these tests: 4 source columns, where source column 1 becomes target
    /// column 0, source column 0 becomes target column 1, and columns 2 and 3 are pruned.
    fn swap_first_two_of_four() -> ColIndexMapping {
        ColIndexMapping::with_remaining_columns(&[1, 0], 4)
    }

    /// Builds a 4-column bitset with the given bits set.
    fn bitset_of_four(indices: &[usize]) -> FixedBitSet {
        let mut set = FixedBitSet::with_capacity(4);
        for &i in indices {
            set.insert(i);
        }
        set
    }

    #[test]
    fn test_rewrite_fd() {
        let mapping = swap_first_two_of_four();
        let new_fd = |from, to| FunctionalDependency::with_indices(4, from, to);
        let fds_with_expected_res = vec![
            // No dependent column survives.
            (new_fd(&[0, 1], &[2, 3]), None),
            // A determinant column is pruned.
            (new_fd(&[2], &[0, 1]), None),
            // Both sides survive, renumbered.
            (
                new_fd(&[1], &[0]),
                Some(FunctionalDependency::with_indices(2, &[0], &[1])),
            ),
        ];
        for (input, expected) in fds_with_expected_res {
            assert_eq!(mapping.rewrite_functional_dependency(&input), expected);
        }
    }

    #[test]
    fn test_rewrite_fd_set() {
        let new_fd = |from, to| FunctionalDependency::with_indices(4, from, to);
        let fd_set = FunctionalDependencySet::with_dependencies(
            4,
            vec![
                // Dropped: no dependent column survives.
                new_fd(&[0, 1], &[2, 3]),
                // Dropped: a determinant column is pruned.
                new_fd(&[2], &[0, 1]),
                new_fd(&[0, 1, 2], &[3]),
                // Dropped: empty dependent side.
                new_fd(&[], &[]),
                new_fd(&[1], &[]),
                // Kept: {} -> {0} becomes {} -> {1}.
                new_fd(&[], &[0]),
                // Kept: {1} -> {0} becomes {0} -> {1}.
                new_fd(&[1], &[0]),
            ],
        );
        let mapping = swap_first_two_of_four();
        let result = mapping.rewrite_functional_dependency_set(fd_set);
        let expected = FunctionalDependencySet::with_dependencies(
            2,
            vec![
                FunctionalDependency::with_indices(2, &[], &[1]),
                FunctionalDependency::with_indices(2, &[0], &[1]),
            ],
        );
        assert_eq!(result, expected);
    }

    #[test]
    fn test_rewrite_bitset() {
        let mapping = swap_first_two_of_four();

        let rewritten = mapping.rewrite_bitset(&bitset_of_four(&[0, 2, 3]));
        assert_eq!(rewritten.len(), 2);
        assert_eq!(rewritten.ones().collect::<Vec<_>>(), vec![1]);

        let rewritten = mapping.rewrite_bitset(&bitset_of_four(&[]));
        assert_eq!(rewritten.len(), 2);
        assert!(rewritten.is_clear());
    }

    #[test]
    fn test_rewrite_dist_key() {
        let mapping = swap_first_two_of_four();
        assert_eq!(mapping.rewrite_dist_key(&[0, 1]), Some(vec![1, 0]));
        assert_eq!(mapping.rewrite_dist_key(&[1]), Some(vec![0]));
        // A single pruned key column invalidates the whole key.
        assert_eq!(mapping.rewrite_dist_key(&[0, 2]), None);
    }

    #[test]
    fn test_rewrite_provided_distribution() {
        let mapping = swap_first_two_of_four();
        assert_eq!(
            mapping.rewrite_provided_distribution(&Distribution::HashShard(vec![0, 1])),
            Distribution::HashShard(vec![1, 0])
        );
        // Key partially pruned: still sharded, but not by any visible column.
        assert_eq!(
            mapping.rewrite_provided_distribution(&Distribution::HashShard(vec![0, 2])),
            Distribution::SomeShard
        );
        assert_eq!(
            mapping.rewrite_provided_distribution(&Distribution::Single),
            Distribution::Single
        );
    }

    #[test]
    fn test_rewrite_required_distribution() {
        let mapping = swap_first_two_of_four();

        // Shard-by-key keeps whichever key columns survive.
        let partially_kept = mapping
            .rewrite_required_distribution(&RequiredDist::ShardByKey(bitset_of_four(&[0, 2])));
        assert!(matches!(
            &partially_kept,
            RequiredDist::ShardByKey(keys) if keys.ones().collect::<Vec<_>>() == vec![1]
        ));

        // ...and degrades to `Any` when none survive.
        let fully_pruned = mapping
            .rewrite_required_distribution(&RequiredDist::ShardByKey(bitset_of_four(&[2, 3])));
        assert!(matches!(fully_pruned, RequiredDist::Any));

        // An exact hash requirement must map every key column.
        let hash_kept = mapping.rewrite_required_distribution(&RequiredDist::PhysicalDist(
            Distribution::HashShard(vec![0]),
        ));
        assert!(matches!(
            &hash_kept,
            RequiredDist::PhysicalDist(Distribution::HashShard(key)) if *key == vec![1]
        ));
        let hash_pruned = mapping.rewrite_required_distribution(&RequiredDist::PhysicalDist(
            Distribution::HashShard(vec![0, 2]),
        ));
        assert!(matches!(hash_pruned, RequiredDist::Any));

        // Column-free requirements pass through.
        assert!(matches!(
            mapping.rewrite_required_distribution(&RequiredDist::Any),
            RequiredDist::Any
        ));
        assert!(matches!(
            mapping.rewrite_required_distribution(&RequiredDist::AnyShard),
            RequiredDist::AnyShard
        ));
    }
}