risingwave_common/util/
column_index_mapping.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::max;
16use std::fmt::Debug;
17use std::vec;
18
19use itertools::Itertools;
20use risingwave_pb::catalog::PbColIndexMapping;
21
/// `ColIndexMapping` is a partial mapping from usize to usize.
///
/// It is used in optimizer for transformation of column index.
///
/// A source index `i` is mapped to target `j` when `map[i] == Some(j)`. It is unmapped when
/// `map[i]` is `None` or when `i` lies outside `0..map.len()`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ColIndexMapping {
    /// The size of the target space, i.e. target index is in the range `(0..target_size)`.
    target_size: usize,
    /// The mapping table, indexed by source index: `map[src]` holds the target index of `src`,
    /// or `None` when `src` is not mapped. Its length is the size of the source space.
    map: Vec<Option<usize>>,
}
32
33impl ColIndexMapping {
34    /// Create a partial mapping which maps from the subscripts range `(0..map.len())` to
35    /// `(0..target_size)`. Each subscript is mapped to the corresponding element.
36    pub fn new(map: Vec<Option<usize>>, target_size: usize) -> Self {
37        if let Some(target_max) = map.iter().filter_map(|x| *x).max_by_key(|x| *x) {
38            assert!(
39                target_max < target_size,
40                "target_max: {}, target_size: {}",
41                target_max,
42                target_size
43            );
44        };
45        Self { target_size, map }
46    }
47
48    pub fn into_parts(self) -> (Vec<Option<usize>>, usize) {
49        (self.map, self.target_size)
50    }
51
52    pub fn to_parts(&self) -> (&[Option<usize>], usize) {
53        (&self.map, self.target_size)
54    }
55
56    pub fn put(&mut self, src: usize, tar: Option<usize>) {
57        assert!(src < self.source_size());
58        if let Some(tar) = tar {
59            assert!(tar < self.target_size());
60        }
61        self.map[src] = tar;
62    }
63
64    pub fn identity(size: usize) -> Self {
65        let map = (0..size).map(Some).collect();
66        Self::new(map, size)
67    }
68
69    pub fn is_identity(&self) -> bool {
70        if self.map.len() != self.target_size {
71            return false;
72        }
73        for (src, tar) in self.map.iter().enumerate() {
74            if let Some(tar_value) = tar
75                && src == *tar_value
76            {
77                continue;
78            } else {
79                return false;
80            }
81        }
82        true
83    }
84
85    pub fn identity_or_none(source_size: usize, target_size: usize) -> Self {
86        let map = (0..source_size)
87            .map(|i| if i < target_size { Some(i) } else { None })
88            .collect();
89        Self::new(map, target_size)
90    }
91
92    pub fn empty(source_size: usize, target_size: usize) -> Self {
93        let map = vec![None; source_size];
94        Self::new(map, target_size)
95    }
96
97    /// Create a partial mapping which maps range `(0..source_num)` to range
98    /// `(offset..offset+source_num)`.
99    ///
100    /// # Examples
101    ///
102    /// Positive offset:
103    ///
104    /// ```ignore
105    /// # use risingwave_frontend::utils::ColIndexMapping;
106    /// let mapping = ColIndexMapping::with_shift_offset(3, 3);
107    /// assert_eq!(mapping.map(0), 3);
108    /// assert_eq!(mapping.map(1), 4);
109    /// assert_eq!(mapping.map(2), 5);
110    /// ```
111    ///
112    /// Negative offset:
113    ///
114    ///  ```ignore
115    /// # use risingwave_frontend::utils::ColIndexMapping;
116    /// let mapping = ColIndexMapping::with_shift_offset(6, -3);
117    /// assert_eq!(mapping.try_map(0), None);
118    /// assert_eq!(mapping.try_map(1), None);
119    /// assert_eq!(mapping.try_map(2), None);
120    /// assert_eq!(mapping.map(3), 0);
121    /// assert_eq!(mapping.map(4), 1);
122    /// assert_eq!(mapping.map(5), 2);
123    /// assert_eq!(mapping.try_map(6), None);
124    /// ```
125    pub fn with_shift_offset(source_num: usize, offset: isize) -> Self {
126        let map = (0..source_num)
127            .map(|source| {
128                let target = source as isize + offset;
129                usize::try_from(target).ok()
130            })
131            .collect_vec();
132        let target_size = usize::try_from(source_num as isize + offset).unwrap();
133        Self::new(map, target_size)
134    }
135
136    /// Maps the smallest index to 0, the next smallest to 1, and so on.
137    ///
138    /// It is useful for column pruning.
139    ///
140    /// # Examples
141    ///
142    /// ```ignore
143    /// # use fixedbitset::FixedBitSet;
144    /// # use risingwave_frontend::utils::ColIndexMapping;
145    /// let mut remaining_cols = vec![1, 3];
146    /// let mapping = ColIndexMapping::with_remaining_columns(&remaining_cols, 4);
147    /// assert_eq!(mapping.map(1), 0);
148    /// assert_eq!(mapping.map(3), 1);
149    /// assert_eq!(mapping.try_map(0), None);
150    /// assert_eq!(mapping.try_map(2), None);
151    /// assert_eq!(mapping.try_map(4), None);
152    /// ```
153    pub fn with_remaining_columns(cols: &[usize], src_size: usize) -> Self {
154        let mut map = vec![None; src_size];
155        for (tar, &src) in cols.iter().enumerate() {
156            map[src] = Some(tar);
157        }
158        Self::new(map, cols.len())
159    }
160
161    // TODO(yuchao): isn't this the same as `with_remaining_columns`?
162    pub fn with_included_columns(cols: &[usize], src_size: usize) -> Self {
163        let mut map = vec![None; src_size];
164        for (tar, &src) in cols.iter().enumerate() {
165            if map[src].is_none() {
166                map[src] = Some(tar);
167            }
168        }
169        Self::new(map, cols.len())
170    }
171
172    /// Remove the given columns, and maps the remaining columns to a consecutive range starting
173    /// from 0.
174    ///
175    /// # Examples
176    ///
177    /// ```ignore
178    /// # use fixedbitset::FixedBitSet;
179    /// # use risingwave_frontend::utils::ColIndexMapping;
180    /// let mut removed_cols = vec![0, 2, 4];
181    /// let mapping = ColIndexMapping::with_removed_columns(&removed_cols, 5);
182    /// assert_eq!(mapping.map(1), 0);
183    /// assert_eq!(mapping.map(3), 1);
184    /// assert_eq!(mapping.try_map(0), None);
185    /// assert_eq!(mapping.try_map(2), None);
186    /// assert_eq!(mapping.try_map(4), None);
187    /// ```
188    pub fn with_removed_columns(cols: &[usize], src_size: usize) -> Self {
189        let cols = (0..src_size).filter(|x| !cols.contains(x)).collect_vec();
190        Self::with_remaining_columns(&cols, src_size)
191    }
192
193    #[must_use]
194    /// Compose column index mappings.
195    /// For example if this maps 0->5,
196    /// and `following` maps 5->1,
197    /// Then the composite has 0->5->1 => 0->1.
198    pub fn composite(&self, following: &Self) -> Self {
199        // debug!("composing {:?} and {:?}", self, following);
200        let mut map = self.map.clone();
201        for target in &mut map {
202            *target = target.and_then(|index| following.try_map(index));
203        }
204        Self::new(map, following.target_size())
205    }
206
207    pub fn clone_with_offset(&self, offset: usize) -> Self {
208        let mut map = self.map.clone();
209        for target in &mut map {
210            *target = target.and_then(|index| index.checked_add(offset));
211        }
212        Self::new(map, self.target_size() + offset)
213    }
214
215    /// Union two mapping, the result mapping `target_size` and source size will be the max size
216    /// of the two mappings.
217    ///
218    /// # Panics
219    ///
220    /// Will panic if a source appears in both to mapping
221    #[must_use]
222    pub fn union(&self, other: &Self) -> Self {
223        // debug!("union {:?} and {:?}", self, other);
224        let target_size = max(self.target_size(), other.target_size());
225        let source_size = max(self.source_size(), other.source_size());
226        let mut map = vec![None; source_size];
227        for (src, dst) in self.mapping_pairs() {
228            assert_eq!(map[src], None);
229            map[src] = Some(dst);
230        }
231        for (src, dst) in other.mapping_pairs() {
232            assert_eq!(map[src], None);
233            map[src] = Some(dst);
234        }
235        Self::new(map, target_size)
236    }
237
238    /// Inverse the mapping. If a target corresponds to more than one source, return `None`.
239    #[must_use]
240    pub fn inverse(&self) -> Option<Self> {
241        let mut map = vec![None; self.target_size()];
242        for (src, dst) in self.mapping_pairs() {
243            if map[dst].is_some() {
244                return None;
245            }
246            map[dst] = Some(src);
247        }
248        Some(Self::new(map, self.source_size()))
249    }
250
251    /// return iter of (src, dst) order by src
252    pub fn mapping_pairs(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
253        self.map
254            .iter()
255            .cloned()
256            .enumerate()
257            .filter_map(|(src, tar)| tar.map(|tar| (src, tar)))
258    }
259
260    /// Try mapping the source index to the target index.
261    pub fn try_map(&self, index: usize) -> Option<usize> {
262        *self.map.get(index)?
263    }
264
265    /// Try mapping all the source indices to the target indices. Returns `None` if any of the
266    /// indices is not mapped.
267    pub fn try_map_all(&self, indices: impl IntoIterator<Item = usize>) -> Option<Vec<usize>> {
268        indices.into_iter().map(|i| self.try_map(i)).collect()
269    }
270
    /// Maps the source index to its target index. This is the panicking counterpart of
    /// `try_map`.
    ///
    /// # Panics
    ///
    /// Will panic if `index >= self.source_size()` or `index` is not mapped.
    pub fn map(&self, index: usize) -> usize {
        self.try_map(index).unwrap()
    }
277
    /// Returns the size of the target range. Target index is in the range `(0..target_size)`.
    ///
    /// This is the stored `target_size` field; it does not depend on which targets are
    /// actually used by the mapping.
    pub fn target_size(&self) -> usize {
        self.target_size
    }
282
    /// Returns the size of the source range. Source index is in the range `(0..source_size)`.
    ///
    /// This is the length of the underlying mapping table, counting unmapped slots as well.
    pub fn source_size(&self) -> usize {
        self.map.len()
    }
287
    /// Returns `true` when the target size is zero.
    ///
    /// Note that only the target size is inspected: the source size is not considered.
    pub fn is_empty(&self) -> bool {
        self.target_size() == 0
    }
291
292    pub fn is_injective(&self) -> bool {
293        let mut tar_exists = vec![false; self.target_size()];
294        for i in self.map.iter().flatten() {
295            if tar_exists[*i] {
296                return false;
297            }
298            tar_exists[*i] = true;
299        }
300        true
301    }
302}
303
304impl ColIndexMapping {
305    pub fn to_protobuf(&self) -> PbColIndexMapping {
306        PbColIndexMapping {
307            target_size: self.target_size as u64,
308            map: self
309                .map
310                .iter()
311                .map(|x| x.map_or(-1, |x| x as i64))
312                .collect(),
313        }
314    }
315
316    pub fn from_protobuf(prost: &PbColIndexMapping) -> ColIndexMapping {
317        ColIndexMapping {
318            target_size: prost.target_size as usize,
319            map: prost.map.iter().map(|&x| x.try_into().ok()).collect(),
320        }
321    }
322}
323
324impl Debug for ColIndexMapping {
325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        write!(
327            f,
328            "ColIndexMapping(source_size:{}, target_size:{}, mapping:{})",
329            self.source_size(),
330            self.target_size(),
331            self.mapping_pairs()
332                .map(|(src, dst)| format!("{}->{}", src, dst))
333                .join(",")
334        )
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero offset maps every source index onto itself and nothing beyond the source range.
    #[test]
    fn test_shift_0() {
        let mapping = ColIndexMapping::with_shift_offset(3, 0);
        for index in 0..3 {
            assert_eq!(mapping.map(index), index);
        }
        for index in 3..5 {
            assert_eq!(mapping.try_map(index), None);
        }
    }

    /// With no source columns, the target size is still determined by the offset.
    #[test]
    fn test_shift_0_source() {
        let mapping = ColIndexMapping::with_shift_offset(0, 3);
        assert_eq!(mapping.target_size(), 3);
    }

    /// Shifting by three and then keeping only columns 3 and 5 leaves sources 0 and 2 mapped.
    #[test]
    fn test_composite() {
        let shift_by_three = ColIndexMapping::with_shift_offset(3, 3);
        let keep_3_and_5 = ColIndexMapping::with_remaining_columns(&[3, 5], 6);
        let composite = shift_by_three.composite(&keep_3_and_5);
        // 0+3 = 3, 3 -> 0; 1+3 = 4 is pruned; 2+3 = 5, 5 -> 1
        let expected = [Some(0), None, Some(1)];
        for (src, want) in expected.into_iter().enumerate() {
            assert_eq!(composite.try_map(src), want);
        }
    }

    /// A mapping built by `identity` must be recognised as an identity mapping.
    #[test]
    fn test_identity() {
        assert!(ColIndexMapping::identity(10).is_identity());
    }
}