risingwave_frontend/utils/
column_index_mapping.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::vec;
16
17use fixedbitset::FixedBitSet;
18pub use risingwave_common::util::column_index_mapping::ColIndexMapping;
19use risingwave_common::util::sort_util::ColumnOrder;
20
21use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef};
22use crate::optimizer::property::{
23    Distribution, FunctionalDependency, FunctionalDependencySet, MonotonicityMap, Order,
24    RequiredDist,
25};
26
27/// Extension trait for [`ColIndexMapping`] to rewrite frontend structures.
28#[easy_ext::ext(ColIndexMappingRewriteExt)]
29impl ColIndexMapping {
30    /// Rewrite the provided order's column index. It will try its best to give the most accurate
31    /// order. Order(0,1,2) with mapping(0->1,1->0,2->2) will be rewritten to Order(1,0,2)
32    /// Order(0,1,2) with mapping(0->1,2->0) will be rewritten to Order(1)
33    pub fn rewrite_provided_order(&self, order: &Order) -> Order {
34        let mut mapped_column_orders = vec![];
35        for column_order in &order.column_orders {
36            match self.try_map(column_order.column_index) {
37                Some(mapped_index) => mapped_column_orders
38                    .push(ColumnOrder::new(mapped_index, column_order.order_type)),
39                None => break,
40            }
41        }
42        Order {
43            column_orders: mapped_column_orders,
44        }
45    }
46
47    /// Rewrite the required order's field index. if it can't give a corresponding
48    /// required order after the column index mapping, it will return None.
49    /// Order(0,1,2) with mapping(0->1,1->0,2->2) will be rewritten to Order(1,0,2)
50    /// Order(0,1,2) with mapping(0->1,2->0) will return None
51    pub fn rewrite_required_order(&self, order: &Order) -> Option<Order> {
52        order
53            .column_orders
54            .iter()
55            .map(|o| {
56                self.try_map(o.column_index)
57                    .map(|mapped_index| ColumnOrder::new(mapped_index, o.order_type))
58            })
59            .collect::<Option<Vec<_>>>()
60            .map(|mapped_column_orders| Order {
61                column_orders: mapped_column_orders,
62            })
63    }
64
65    /// Rewrite the distribution key and will return None if **any** index of the key disappear
66    /// after the mapping.
67    pub fn rewrite_dist_key(&self, key: &[usize]) -> Option<Vec<usize>> {
68        self.try_map_all(key.iter().copied())
69    }
70
71    /// Rewrite the provided distribution's field index. It will try its best to give the most
72    /// accurate distribution.
73    /// HashShard(0,1,2), with mapping(0->1,1->0,2->2) will be rewritten to HashShard(1,0,2).
74    /// HashShard(0,1,2), with mapping(0->1,2->0) will be rewritten to `SomeShard`.
75    ///
76    /// For non-HashShard distribution, it will return the original distribution.
77    pub fn rewrite_provided_distribution(&self, dist: &Distribution) -> Distribution {
78        match dist {
79            Distribution::Single | Distribution::SomeShard | Distribution::Broadcast => {
80                dist.clone()
81            }
82
83            Distribution::HashShard(_) => match self.rewrite_dist_key(dist.dist_column_indices()) {
84                Some(mapped_dist_key) => Distribution::HashShard(mapped_dist_key),
85                None => Distribution::SomeShard,
86            },
87            Distribution::UpstreamHashShard(_, table_id) => {
88                match self.rewrite_dist_key(dist.dist_column_indices()) {
89                    Some(mapped_dist_key) => {
90                        Distribution::UpstreamHashShard(mapped_dist_key, *table_id)
91                    }
92                    None => Distribution::SomeShard,
93                }
94            }
95        }
96    }
97
98    /// Rewrite the required distribution's field index. if it can't give a corresponding
99    /// required distribution after the column index mapping, it will return None.
100    /// ShardByKey(0,1,2), with mapping(0->1,1->0,2->2) will be rewritten to ShardByKey(1,0,2).
101    /// ShardByKey(0,1,2), with mapping(0->1,2->0) will return ShardByKey(1,0).
102    /// ShardByExactKey(0,1,2), with mapping(0->1,1->0,2->2) will be rewritten to ShardByExactKey(1,0,2).
103    /// ShardByExactKey(0,1,2), with mapping(0->1,2->0) will return `Any`.
104    /// ShardByKey(0,1), with mapping(2->0) will return `Any`.
105    pub fn rewrite_required_distribution(&self, dist: &RequiredDist) -> RequiredDist {
106        match dist {
107            RequiredDist::ShardByExactKey(keys) => {
108                assert!(!keys.is_clear());
109                let original_key_count = keys.count_ones(..);
110                let keys = self.rewrite_bitset(keys);
111                if keys.count_ones(..) != original_key_count {
112                    RequiredDist::Any
113                } else {
114                    RequiredDist::ShardByExactKey(keys)
115                }
116            }
117            RequiredDist::ShardByKey(keys) => {
118                assert!(!keys.is_clear());
119                let keys = self.rewrite_bitset(keys);
120                if keys.count_ones(..) == 0 {
121                    RequiredDist::Any
122                } else {
123                    RequiredDist::ShardByKey(keys)
124                }
125            }
126            RequiredDist::PhysicalDist(dist) => match dist {
127                Distribution::HashShard(keys) => {
128                    assert!(!keys.is_empty());
129                    let keys = self.rewrite_dist_key(keys);
130                    match keys {
131                        Some(keys) => RequiredDist::PhysicalDist(Distribution::HashShard(keys)),
132                        None => RequiredDist::Any,
133                    }
134                }
135                Distribution::UpstreamHashShard(keys, table_id) => {
136                    assert!(!keys.is_empty());
137                    let keys = self.rewrite_dist_key(keys);
138                    match keys {
139                        Some(keys) => RequiredDist::PhysicalDist(Distribution::UpstreamHashShard(
140                            keys, *table_id,
141                        )),
142                        None => RequiredDist::Any,
143                    }
144                }
145                Distribution::Single | Distribution::Broadcast | Distribution::SomeShard => {
146                    RequiredDist::PhysicalDist(dist.clone())
147                }
148            },
149            RequiredDist::Any => RequiredDist::Any,
150            RequiredDist::AnyShard => RequiredDist::AnyShard,
151        }
152    }
153
154    /// Rewrite the indices in a functional dependency.
155    ///
156    /// If some columns in the `from` side are removed, then this fd is no longer valid. For
157    /// example, for ABC --> D, it means that A, B, and C together can determine D. But if B is
158    /// removed, this fd is not valid. For this case, we will return [`None`]
159    ///
160    /// Additionally, If the `to` side of a functional dependency becomes empty after rewriting, it
161    /// means that this dependency is unneeded so we also return [`None`].
162    pub fn rewrite_functional_dependency(
163        &self,
164        fd: &FunctionalDependency,
165    ) -> Option<FunctionalDependency> {
166        let new_from = self.rewrite_bitset(fd.from());
167        let new_to = self.rewrite_bitset(fd.to());
168        if new_from.count_ones(..) != fd.from().count_ones(..) || new_to.is_clear() {
169            None
170        } else {
171            Some(FunctionalDependency::new(new_from, new_to))
172        }
173    }
174
175    /// Rewrite functional dependencies in `fd_set` one by one, using
176    /// `[ColIndexMapping::rewrite_functional_dependency]`.
177    ///
178    /// Note that this rewrite process handles each function dependency independently.
179    /// Relationships within function dependencies are not considered.
180    /// For example, if we have `fd_set` { AB --> C, A --> B }, and column B is removed.
181    /// The result would be an empty `fd_set`, rather than { A --> C }.
182    pub fn rewrite_functional_dependency_set(
183        &self,
184        fd_set: FunctionalDependencySet,
185    ) -> FunctionalDependencySet {
186        let mut new_fd_set = FunctionalDependencySet::new(self.target_size());
187        for i in fd_set.into_dependencies() {
188            if let Some(fd) = self.rewrite_functional_dependency(&i) {
189                new_fd_set.add_functional_dependency(fd);
190            }
191        }
192        new_fd_set
193    }
194
195    pub fn rewrite_bitset(&self, bitset: &FixedBitSet) -> FixedBitSet {
196        assert_eq!(bitset.len(), self.source_size());
197        let mut ret = FixedBitSet::with_capacity(self.target_size());
198        for i in bitset.ones() {
199            if let Some(i) = self.try_map(i) {
200                ret.insert(i);
201            }
202        }
203        ret
204    }
205
206    pub fn rewrite_monotonicity_map(&self, map: &MonotonicityMap) -> MonotonicityMap {
207        let mut new_map = MonotonicityMap::new();
208        for (i, monotonicity) in map.iter() {
209            if let Some(mapped_i) = self.try_map(i) {
210                new_map.insert(mapped_i, monotonicity);
211            }
212        }
213        new_map
214    }
215}
216
217impl ExprRewriter for ColIndexMapping {
218    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
219        InputRef::new(self.map(input_ref.index()), input_ref.return_type()).into()
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a functional dependency over the 4-column input used by these tests.
    fn fd_in(from: &[usize], to: &[usize]) -> FunctionalDependency {
        FunctionalDependency::with_indices(4, from, to)
    }

    /// Build a functional dependency over the 2-column output of the mapping under test.
    fn fd_out(from: &[usize], to: &[usize]) -> FunctionalDependency {
        FunctionalDependency::with_indices(2, from, to)
    }

    #[test]
    fn test_rewrite_fd() {
        // Input column 1 becomes output 0 and input column 0 becomes output 1.
        // Input columns 2 and 3 are removed.
        let mapping = ColIndexMapping::with_remaining_columns(&[1, 0], 4);
        let cases = [
            // The whole `to` side is removed.
            (fd_in(&[0, 1], &[2, 3]), None),
            // A `from` column is removed.
            (fd_in(&[2], &[0, 1]), None),
            // Both sides survive and are renumbered.
            (fd_in(&[1], &[0]), Some(fd_out(&[0], &[1]))),
        ];
        for (input, expected) in &cases {
            assert_eq!(&mapping.rewrite_functional_dependency(input), expected);
        }
    }

    #[test]
    fn test_rewrite_fd_set() {
        let fd_set = FunctionalDependencySet::with_dependencies(
            4,
            vec![
                // Dropped: a `from` column is removed, or the whole `to` side is removed.
                fd_in(&[0, 1], &[2, 3]),
                fd_in(&[2], &[0, 1]),
                fd_in(&[0, 1, 2], &[3]),
                // Dependencies with an empty `to` side do not appear in the result.
                fd_in(&[], &[]),
                fd_in(&[1], &[]),
                // A constant column (empty `from`) is kept.
                fd_in(&[], &[0]),
                // Kept and renumbered.
                fd_in(&[1], &[0]),
            ],
        );
        let mapping = ColIndexMapping::with_remaining_columns(&[1, 0], 4);
        let expected = FunctionalDependencySet::with_dependencies(
            2,
            vec![fd_out(&[], &[1]), fd_out(&[0], &[1])],
        );
        assert_eq!(mapping.rewrite_functional_dependency_set(fd_set), expected);
    }
}