risingwave_frontend/utils/column_index_mapping.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::vec;
16
17use fixedbitset::FixedBitSet;
18pub use risingwave_common::util::column_index_mapping::ColIndexMapping;
19use risingwave_common::util::sort_util::ColumnOrder;
20
21use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef};
22use crate::optimizer::property::{
23    Distribution, FunctionalDependency, FunctionalDependencySet, MonotonicityMap, Order,
24    RequiredDist,
25};
26
27/// Extension trait for [`ColIndexMapping`] to rewrite frontend structures.
28#[easy_ext::ext(ColIndexMappingRewriteExt)]
29impl ColIndexMapping {
30    /// Rewrite the provided order's column index. It will try its best to give the most accurate
31    /// order. Order(0,1,2) with mapping(0->1,1->0,2->2) will be rewritten to Order(1,0,2)
32    /// Order(0,1,2) with mapping(0->1,2->0) will be rewritten to Order(1)
33    pub fn rewrite_provided_order(&self, order: &Order) -> Order {
34        let mut mapped_column_orders = vec![];
35        for column_order in &order.column_orders {
36            match self.try_map(column_order.column_index) {
37                Some(mapped_index) => mapped_column_orders
38                    .push(ColumnOrder::new(mapped_index, column_order.order_type)),
39                None => break,
40            }
41        }
42        Order {
43            column_orders: mapped_column_orders,
44        }
45    }
46
47    /// Rewrite the required order's field index. if it can't give a corresponding
48    /// required order after the column index mapping, it will return None.
49    /// Order(0,1,2) with mapping(0->1,1->0,2->2) will be rewritten to Order(1,0,2)
50    /// Order(0,1,2) with mapping(0->1,2->0) will return None
51    pub fn rewrite_required_order(&self, order: &Order) -> Option<Order> {
52        order
53            .column_orders
54            .iter()
55            .map(|o| {
56                self.try_map(o.column_index)
57                    .map(|mapped_index| ColumnOrder::new(mapped_index, o.order_type))
58            })
59            .collect::<Option<Vec<_>>>()
60            .map(|mapped_column_orders| Order {
61                column_orders: mapped_column_orders,
62            })
63    }
64
65    /// Rewrite the distribution key and will return None if **any** index of the key disappear
66    /// after the mapping.
67    pub fn rewrite_dist_key(&self, key: &[usize]) -> Option<Vec<usize>> {
68        self.try_map_all(key.iter().copied())
69    }
70
71    /// Rewrite the provided distribution's field index. It will try its best to give the most
72    /// accurate distribution.
73    /// HashShard(0,1,2), with mapping(0->1,1->0,2->2) will be rewritten to HashShard(1,0,2).
74    /// HashShard(0,1,2), with mapping(0->1,2->0) will be rewritten to `SomeShard`.
75    ///
76    /// For non-HashShard distribution, it will return the original distribution.
77    pub fn rewrite_provided_distribution(&self, dist: &Distribution) -> Distribution {
78        match dist {
79            Distribution::Single | Distribution::SomeShard | Distribution::Broadcast => {
80                dist.clone()
81            }
82
83            Distribution::HashShard(_) => match self.rewrite_dist_key(dist.dist_column_indices()) {
84                Some(mapped_dist_key) => Distribution::HashShard(mapped_dist_key),
85                None => Distribution::SomeShard,
86            },
87            Distribution::UpstreamHashShard(_, table_id) => {
88                match self.rewrite_dist_key(dist.dist_column_indices()) {
89                    Some(mapped_dist_key) => {
90                        Distribution::UpstreamHashShard(mapped_dist_key, *table_id)
91                    }
92                    None => Distribution::SomeShard,
93                }
94            }
95        }
96    }
97
98    /// Rewrite the required distribution's field index. if it can't give a corresponding
99    /// required distribution after the column index mapping, it will return None.
100    /// ShardByKey(0,1,2), with mapping(0->1,1->0,2->2) will be rewritten to ShardByKey(1,0,2).
101    /// ShardByKey(0,1,2), with mapping(0->1,2->0) will return ShardByKey(1,0).
102    /// ShardByKey(0,1), with mapping(2->0) will return `Any`.
103    pub fn rewrite_required_distribution(&self, dist: &RequiredDist) -> RequiredDist {
104        match dist {
105            RequiredDist::ShardByKey(keys) => {
106                assert!(!keys.is_clear());
107                let keys = self.rewrite_bitset(keys);
108                if keys.count_ones(..) == 0 {
109                    RequiredDist::Any
110                } else {
111                    RequiredDist::ShardByKey(keys)
112                }
113            }
114            RequiredDist::PhysicalDist(dist) => match dist {
115                Distribution::HashShard(keys) => {
116                    assert!(!keys.is_empty());
117                    let keys = self.rewrite_dist_key(keys);
118                    match keys {
119                        Some(keys) => RequiredDist::PhysicalDist(Distribution::HashShard(keys)),
120                        None => RequiredDist::Any,
121                    }
122                }
123                Distribution::UpstreamHashShard(keys, table_id) => {
124                    assert!(!keys.is_empty());
125                    let keys = self.rewrite_dist_key(keys);
126                    match keys {
127                        Some(keys) => RequiredDist::PhysicalDist(Distribution::UpstreamHashShard(
128                            keys, *table_id,
129                        )),
130                        None => RequiredDist::Any,
131                    }
132                }
133                Distribution::Single | Distribution::Broadcast | Distribution::SomeShard => {
134                    RequiredDist::PhysicalDist(dist.clone())
135                }
136            },
137            RequiredDist::Any => RequiredDist::Any,
138            RequiredDist::AnyShard => RequiredDist::AnyShard,
139        }
140    }
141
142    /// Rewrite the indices in a functional dependency.
143    ///
144    /// If some columns in the `from` side are removed, then this fd is no longer valid. For
145    /// example, for ABC --> D, it means that A, B, and C together can determine D. But if B is
146    /// removed, this fd is not valid. For this case, we will return [`None`]
147    ///
148    /// Additionally, If the `to` side of a functional dependency becomes empty after rewriting, it
149    /// means that this dependency is unneeded so we also return [`None`].
150    pub fn rewrite_functional_dependency(
151        &self,
152        fd: &FunctionalDependency,
153    ) -> Option<FunctionalDependency> {
154        let new_from = self.rewrite_bitset(fd.from());
155        let new_to = self.rewrite_bitset(fd.to());
156        if new_from.count_ones(..) != fd.from().count_ones(..) || new_to.is_clear() {
157            None
158        } else {
159            Some(FunctionalDependency::new(new_from, new_to))
160        }
161    }
162
163    /// Rewrite functional dependencies in `fd_set` one by one, using
164    /// `[ColIndexMapping::rewrite_functional_dependency]`.
165    ///
166    /// Note that this rewrite process handles each function dependency independently.
167    /// Relationships within function dependencies are not considered.
168    /// For example, if we have `fd_set` { AB --> C, A --> B }, and column B is removed.
169    /// The result would be an empty `fd_set`, rather than { A --> C }.
170    pub fn rewrite_functional_dependency_set(
171        &self,
172        fd_set: FunctionalDependencySet,
173    ) -> FunctionalDependencySet {
174        let mut new_fd_set = FunctionalDependencySet::new(self.target_size());
175        for i in fd_set.into_dependencies() {
176            if let Some(fd) = self.rewrite_functional_dependency(&i) {
177                new_fd_set.add_functional_dependency(fd);
178            }
179        }
180        new_fd_set
181    }
182
183    pub fn rewrite_bitset(&self, bitset: &FixedBitSet) -> FixedBitSet {
184        assert_eq!(bitset.len(), self.source_size());
185        let mut ret = FixedBitSet::with_capacity(self.target_size());
186        for i in bitset.ones() {
187            if let Some(i) = self.try_map(i) {
188                ret.insert(i);
189            }
190        }
191        ret
192    }
193
194    pub fn rewrite_monotonicity_map(&self, map: &MonotonicityMap) -> MonotonicityMap {
195        let mut new_map = MonotonicityMap::new();
196        for (i, monotonicity) in map.iter() {
197            if let Some(mapped_i) = self.try_map(i) {
198                new_map.insert(mapped_i, monotonicity);
199            }
200        }
201        new_map
202    }
203}
204
205impl ExprRewriter for ColIndexMapping {
206    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
207        InputRef::new(self.map(input_ref.index()), input_ref.return_type()).into()
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rewrite_fd() {
        // Keep source columns 1 and 0 (in that order) out of 4: source 1 becomes target 0 and
        // source 0 becomes target 1, while source columns 2 and 3 are dropped.
        let mapping = ColIndexMapping::with_remaining_columns(&[1, 0], 4);
        let fd4 = |from, to| FunctionalDependency::with_indices(4, from, to);

        // Every dependent column is dropped, so nothing is left to determine.
        assert_eq!(
            mapping.rewrite_functional_dependency(&fd4(&[0, 1], &[2, 3])),
            None
        );
        // A determinant column is dropped, so the dependency no longer holds.
        assert_eq!(
            mapping.rewrite_functional_dependency(&fd4(&[2], &[0, 1])),
            None
        );
        // Both sides survive and are renumbered: {1} --> {0} becomes {0} --> {1}.
        assert_eq!(
            mapping.rewrite_functional_dependency(&fd4(&[1], &[0])),
            Some(FunctionalDependency::with_indices(2, &[0], &[1])),
        );
    }

    #[test]
    fn test_rewrite_fd_set() {
        let fd4 = |from, to| FunctionalDependency::with_indices(4, from, to);
        let fd2 = |from, to| FunctionalDependency::with_indices(2, from, to);

        let input = FunctionalDependencySet::with_dependencies(
            4,
            vec![
                // Dropped: all dependents vanish, or some determinant vanishes.
                fd4(&[0, 1], &[2, 3]),
                fd4(&[2], &[0, 1]),
                fd4(&[0, 1, 2], &[3]),
                // Dropped: the dependent side is empty.
                fd4(&[], &[]),
                fd4(&[1], &[]),
                // Kept: an empty determinant marks a constant column.
                fd4(&[], &[0]),
                // Kept and renumbered.
                fd4(&[1], &[0]),
            ],
        );
        let mapping = ColIndexMapping::with_remaining_columns(&[1, 0], 4);
        let expected =
            FunctionalDependencySet::with_dependencies(2, vec![fd2(&[], &[1]), fd2(&[0], &[1])]);

        assert_eq!(mapping.rewrite_functional_dependency_set(input), expected);
    }
}