risingwave_frontend/utils/
column_index_mapping.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::vec;
16
17use fixedbitset::FixedBitSet;
18pub use risingwave_common::util::column_index_mapping::ColIndexMapping;
19use risingwave_common::util::sort_util::ColumnOrder;
20
21use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef};
22use crate::optimizer::property::{
23    Distribution, FunctionalDependency, FunctionalDependencySet, MonotonicityMap, Order,
24    RequiredDist,
25};
26
27/// Extension trait for [`ColIndexMapping`] to rewrite frontend structures.
28#[easy_ext::ext(ColIndexMappingRewriteExt)]
29impl ColIndexMapping {
30    /// Rewrite the provided order's column index. It will try its best to give the most accurate
31    /// order. Order(0,1,2) with mapping(0->1,1->0,2->2) will be rewritten to Order(1,0,2)
32    /// Order(0,1,2) with mapping(0->1,2->0) will be rewritten to Order(1)
33    pub fn rewrite_provided_order(&self, order: &Order) -> Order {
34        let mut mapped_column_orders = vec![];
35        for column_order in &order.column_orders {
36            match self.try_map(column_order.column_index) {
37                Some(mapped_index) => mapped_column_orders
38                    .push(ColumnOrder::new(mapped_index, column_order.order_type)),
39                None => break,
40            }
41        }
42        Order {
43            column_orders: mapped_column_orders,
44        }
45    }
46
47    /// Rewrite the required order's field index. if it can't give a corresponding
48    /// required order after the column index mapping, it will return None.
49    /// Order(0,1,2) with mapping(0->1,1->0,2->2) will be rewritten to Order(1,0,2)
50    /// Order(0,1,2) with mapping(0->1,2->0) will return None
51    pub fn rewrite_required_order(&self, order: &Order) -> Option<Order> {
52        order
53            .column_orders
54            .iter()
55            .map(|o| {
56                self.try_map(o.column_index)
57                    .map(|mapped_index| ColumnOrder::new(mapped_index, o.order_type))
58            })
59            .collect::<Option<Vec<_>>>()
60            .map(|mapped_column_orders| Order {
61                column_orders: mapped_column_orders,
62            })
63    }
64
65    /// Rewrite the distribution key and will return None if **any** index of the key disappear
66    /// after the mapping.
67    pub fn rewrite_dist_key(&self, key: &[usize]) -> Option<Vec<usize>> {
68        self.try_map_all(key.iter().copied())
69    }
70
71    /// Rewrite the provided distribution's field index. It will try its best to give the most
72    /// accurate distribution.
73    /// HashShard(0,1,2), with mapping(0->1,1->0,2->2) will be rewritten to HashShard(1,0,2).
74    /// HashShard(0,1,2), with mapping(0->1,2->0) will be rewritten to `SomeShard`.
75    pub fn rewrite_provided_distribution(&self, dist: &Distribution) -> Distribution {
76        let mapped_dist_key = self.rewrite_dist_key(dist.dist_column_indices());
77
78        match (mapped_dist_key, dist) {
79            (None, Distribution::HashShard(_)) | (None, Distribution::UpstreamHashShard(_, _)) => {
80                Distribution::SomeShard
81            }
82            (Some(mapped_dist_key), Distribution::HashShard(_)) => {
83                Distribution::HashShard(mapped_dist_key)
84            }
85            (Some(mapped_dist_key), Distribution::UpstreamHashShard(_, table_id)) => {
86                Distribution::UpstreamHashShard(mapped_dist_key, *table_id)
87            }
88            _ => {
89                assert!(dist.dist_column_indices().is_empty());
90                dist.clone()
91            }
92        }
93    }
94
95    /// Rewrite the required distribution's field index. if it can't give a corresponding
96    /// required distribution after the column index mapping, it will return None.
97    /// ShardByKey(0,1,2), with mapping(0->1,1->0,2->2) will be rewritten to ShardByKey(1,0,2).
98    /// ShardByKey(0,1,2), with mapping(0->1,2->0) will return ShardByKey(1,0).
99    /// ShardByKey(0,1), with mapping(2->0) will return `Any`.
100    pub fn rewrite_required_distribution(&self, dist: &RequiredDist) -> RequiredDist {
101        match dist {
102            RequiredDist::ShardByKey(keys) => {
103                assert!(!keys.is_clear());
104                let keys = self.rewrite_bitset(keys);
105                if keys.count_ones(..) == 0 {
106                    RequiredDist::Any
107                } else {
108                    RequiredDist::ShardByKey(keys)
109                }
110            }
111            RequiredDist::PhysicalDist(dist) => match dist {
112                Distribution::HashShard(keys) => {
113                    assert!(!keys.is_empty());
114                    let keys = self.rewrite_dist_key(keys);
115                    match keys {
116                        Some(keys) => RequiredDist::PhysicalDist(Distribution::HashShard(keys)),
117                        None => RequiredDist::Any,
118                    }
119                }
120                Distribution::UpstreamHashShard(keys, table_id) => {
121                    assert!(!keys.is_empty());
122                    let keys = self.rewrite_dist_key(keys);
123                    match keys {
124                        Some(keys) => RequiredDist::PhysicalDist(Distribution::UpstreamHashShard(
125                            keys, *table_id,
126                        )),
127                        None => RequiredDist::Any,
128                    }
129                }
130                Distribution::Single => RequiredDist::PhysicalDist(Distribution::Single),
131                Distribution::Broadcast => RequiredDist::PhysicalDist(Distribution::Broadcast),
132                Distribution::SomeShard => RequiredDist::PhysicalDist(Distribution::SomeShard),
133            },
134            RequiredDist::Any => RequiredDist::Any,
135            RequiredDist::AnyShard => RequiredDist::AnyShard,
136        }
137    }
138
139    /// Rewrite the indices in a functional dependency.
140    ///
141    /// If some columns in the `from` side are removed, then this fd is no longer valid. For
142    /// example, for ABC --> D, it means that A, B, and C together can determine C. But if B is
143    /// removed, this fd is not valid. For this case, we will return [`None`]
144    ///
145    /// Additionally, If the `to` side of a functional dependency becomes empty after rewriting, it
146    /// means that this dependency is unneeded so we also return [`None`].
147    pub fn rewrite_functional_dependency(
148        &self,
149        fd: &FunctionalDependency,
150    ) -> Option<FunctionalDependency> {
151        let new_from = self.rewrite_bitset(fd.from());
152        let new_to = self.rewrite_bitset(fd.to());
153        if new_from.count_ones(..) != fd.from().count_ones(..) || new_to.is_clear() {
154            None
155        } else {
156            Some(FunctionalDependency::new(new_from, new_to))
157        }
158    }
159
160    /// Rewrite functional dependencies in `fd_set` one by one, using
161    /// `[ColIndexMapping::rewrite_functional_dependency]`.
162    ///
163    /// Note that this rewrite process handles each function dependency independently.
164    /// Relationships within function dependencies are not considered.
165    /// For example, if we have `fd_set` { AB --> C, A --> B }, and column B is removed.
166    /// The result would be an empty `fd_set`, rather than { A --> C }.
167    pub fn rewrite_functional_dependency_set(
168        &self,
169        fd_set: FunctionalDependencySet,
170    ) -> FunctionalDependencySet {
171        let mut new_fd_set = FunctionalDependencySet::new(self.target_size());
172        for i in fd_set.into_dependencies() {
173            if let Some(fd) = self.rewrite_functional_dependency(&i) {
174                new_fd_set.add_functional_dependency(fd);
175            }
176        }
177        new_fd_set
178    }
179
180    pub fn rewrite_bitset(&self, bitset: &FixedBitSet) -> FixedBitSet {
181        assert_eq!(bitset.len(), self.source_size());
182        let mut ret = FixedBitSet::with_capacity(self.target_size());
183        for i in bitset.ones() {
184            if let Some(i) = self.try_map(i) {
185                ret.insert(i);
186            }
187        }
188        ret
189    }
190
191    pub fn rewrite_monotonicity_map(&self, map: &MonotonicityMap) -> MonotonicityMap {
192        let mut new_map = MonotonicityMap::new();
193        for (i, monotonicity) in map.iter() {
194            if let Some(mapped_i) = self.try_map(i) {
195                new_map.insert(mapped_i, monotonicity);
196            }
197        }
198        new_map
199    }
200}
201
202impl ExprRewriter for ColIndexMapping {
203    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
204        InputRef::new(self.map(input_ref.index()), input_ref.return_type()).into()
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Individual functional dependencies are dropped when the `from` side loses a column or the
    /// `to` side becomes empty, and re-indexed otherwise.
    #[test]
    fn test_rewrite_fd() {
        let mapping = ColIndexMapping::with_remaining_columns(&[1, 0], 4);
        let new_fd = |from, to| FunctionalDependency::with_indices(4, from, to);

        // The `to` side {2, 3} is removed entirely.
        assert_eq!(
            mapping.rewrite_functional_dependency(&new_fd(&[0, 1], &[2, 3])),
            None
        );
        // The `from` side {2} is removed.
        assert_eq!(
            mapping.rewrite_functional_dependency(&new_fd(&[2], &[0, 1])),
            None
        );
        // Both sides survive and are re-indexed.
        assert_eq!(
            mapping.rewrite_functional_dependency(&new_fd(&[1], &[0])),
            Some(FunctionalDependency::with_indices(2, &[0], &[1]))
        );
    }

    /// Rewriting a whole set keeps only the dependencies that survive individually.
    #[test]
    fn test_rewrite_fd_set() {
        let mapping = ColIndexMapping::with_remaining_columns(&[1, 0], 4);
        let new_fd = |from, to| FunctionalDependency::with_indices(4, from, to);

        let dependencies = vec![
            // removed
            new_fd(&[0, 1], &[2, 3]),
            new_fd(&[2], &[0, 1]),
            new_fd(&[0, 1, 2], &[3]),
            // empty mappings will be removed
            new_fd(&[], &[]),
            new_fd(&[1], &[]),
            // constant column mapping will be kept
            new_fd(&[], &[0]),
            // kept
            new_fd(&[1], &[0]),
        ];
        let fd_set = FunctionalDependencySet::with_dependencies(4, dependencies);

        let surviving = vec![
            FunctionalDependency::with_indices(2, &[], &[1]),
            FunctionalDependency::with_indices(2, &[0], &[1]),
        ];
        let expected = FunctionalDependencySet::with_dependencies(2, surviving);

        let result = mapping.rewrite_functional_dependency_set(fd_set);
        assert_eq!(result, expected);
    }
}