risingwave_common/util/column_index_mapping.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::max;
16use std::fmt::Debug;
17use std::vec;
18
19use itertools::Itertools;
20use risingwave_pb::catalog::PbColIndexMapping;
21use risingwave_pb::stream_plan::DispatchStrategy;
22
/// A partial mapping between column indices: every source index in `(0..map.len())` maps to
/// at most one target index in `(0..target_size)`.
///
/// The optimizer uses it to translate column indices across plan transformations.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ColIndexMapping {
    /// Number of target indices; every mapped target lies in `(0..target_size)`.
    target_size: usize,
    /// `map[src]` is the target of source index `src`, or `None` when `src` is unmapped.
    map: Vec<Option<usize>>,
}
33
34impl ColIndexMapping {
35    /// Create a partial mapping which maps from the subscripts range `(0..map.len())` to
36    /// `(0..target_size)`. Each subscript is mapped to the corresponding element.
37    pub fn new(map: Vec<Option<usize>>, target_size: usize) -> Self {
38        if let Some(target_max) = map.iter().filter_map(|x| *x).max_by_key(|x| *x) {
39            assert!(
40                target_max < target_size,
41                "target_max: {}, target_size: {}",
42                target_max,
43                target_size
44            );
45        };
46        Self { target_size, map }
47    }
48
49    pub fn into_parts(self) -> (Vec<Option<usize>>, usize) {
50        (self.map, self.target_size)
51    }
52
53    pub fn to_parts(&self) -> (&[Option<usize>], usize) {
54        (&self.map, self.target_size)
55    }
56
57    pub fn put(&mut self, src: usize, tar: Option<usize>) {
58        assert!(src < self.source_size());
59        if let Some(tar) = tar {
60            assert!(tar < self.target_size());
61        }
62        self.map[src] = tar;
63    }
64
65    pub fn identity(size: usize) -> Self {
66        let map = (0..size).map(Some).collect();
67        Self::new(map, size)
68    }
69
70    pub fn is_identity(&self) -> bool {
71        if self.map.len() != self.target_size {
72            return false;
73        }
74        for (src, tar) in self.map.iter().enumerate() {
75            if let Some(tar_value) = tar
76                && src == *tar_value
77            {
78                continue;
79            } else {
80                return false;
81            }
82        }
83        true
84    }
85
86    pub fn identity_or_none(source_size: usize, target_size: usize) -> Self {
87        let map = (0..source_size)
88            .map(|i| if i < target_size { Some(i) } else { None })
89            .collect();
90        Self::new(map, target_size)
91    }
92
93    pub fn empty(source_size: usize, target_size: usize) -> Self {
94        let map = vec![None; source_size];
95        Self::new(map, target_size)
96    }
97
98    /// Create a partial mapping which maps range `(0..source_num)` to range
99    /// `(offset..offset+source_num)`.
100    ///
101    /// # Examples
102    ///
103    /// Positive offset:
104    ///
105    /// ```ignore
106    /// # use risingwave_frontend::utils::ColIndexMapping;
107    /// let mapping = ColIndexMapping::with_shift_offset(3, 3);
108    /// assert_eq!(mapping.map(0), 3);
109    /// assert_eq!(mapping.map(1), 4);
110    /// assert_eq!(mapping.map(2), 5);
111    /// ```
112    ///
113    /// Negative offset:
114    ///
115    ///  ```ignore
116    /// # use risingwave_frontend::utils::ColIndexMapping;
117    /// let mapping = ColIndexMapping::with_shift_offset(6, -3);
118    /// assert_eq!(mapping.try_map(0), None);
119    /// assert_eq!(mapping.try_map(1), None);
120    /// assert_eq!(mapping.try_map(2), None);
121    /// assert_eq!(mapping.map(3), 0);
122    /// assert_eq!(mapping.map(4), 1);
123    /// assert_eq!(mapping.map(5), 2);
124    /// assert_eq!(mapping.try_map(6), None);
125    /// ```
126    pub fn with_shift_offset(source_num: usize, offset: isize) -> Self {
127        let map = (0..source_num)
128            .map(|source| {
129                let target = source as isize + offset;
130                usize::try_from(target).ok()
131            })
132            .collect_vec();
133        let target_size = usize::try_from(source_num as isize + offset).unwrap();
134        Self::new(map, target_size)
135    }
136
137    /// Maps the smallest index to 0, the next smallest to 1, and so on.
138    ///
139    /// It is useful for column pruning.
140    ///
141    /// # Examples
142    ///
143    /// ```ignore
144    /// # use fixedbitset::FixedBitSet;
145    /// # use risingwave_frontend::utils::ColIndexMapping;
146    /// let mut remaining_cols = vec![1, 3];
147    /// let mapping = ColIndexMapping::with_remaining_columns(&remaining_cols, 4);
148    /// assert_eq!(mapping.map(1), 0);
149    /// assert_eq!(mapping.map(3), 1);
150    /// assert_eq!(mapping.try_map(0), None);
151    /// assert_eq!(mapping.try_map(2), None);
152    /// assert_eq!(mapping.try_map(4), None);
153    /// ```
154    pub fn with_remaining_columns(cols: &[usize], src_size: usize) -> Self {
155        let mut map = vec![None; src_size];
156        for (tar, &src) in cols.iter().enumerate() {
157            map[src] = Some(tar);
158        }
159        Self::new(map, cols.len())
160    }
161
162    // TODO(yuchao): isn't this the same as `with_remaining_columns`?
163    pub fn with_included_columns(cols: &[usize], src_size: usize) -> Self {
164        let mut map = vec![None; src_size];
165        for (tar, &src) in cols.iter().enumerate() {
166            if map[src].is_none() {
167                map[src] = Some(tar);
168            }
169        }
170        Self::new(map, cols.len())
171    }
172
173    /// Remove the given columns, and maps the remaining columns to a consecutive range starting
174    /// from 0.
175    ///
176    /// # Examples
177    ///
178    /// ```ignore
179    /// # use fixedbitset::FixedBitSet;
180    /// # use risingwave_frontend::utils::ColIndexMapping;
181    /// let mut removed_cols = vec![0, 2, 4];
182    /// let mapping = ColIndexMapping::with_removed_columns(&removed_cols, 5);
183    /// assert_eq!(mapping.map(1), 0);
184    /// assert_eq!(mapping.map(3), 1);
185    /// assert_eq!(mapping.try_map(0), None);
186    /// assert_eq!(mapping.try_map(2), None);
187    /// assert_eq!(mapping.try_map(4), None);
188    /// ```
189    pub fn with_removed_columns(cols: &[usize], src_size: usize) -> Self {
190        let cols = (0..src_size).filter(|x| !cols.contains(x)).collect_vec();
191        Self::with_remaining_columns(&cols, src_size)
192    }
193
194    #[must_use]
195    /// Compose column index mappings.
196    /// For example if this maps 0->5,
197    /// and `following` maps 5->1,
198    /// Then the composite has 0->5->1 => 0->1.
199    pub fn composite(&self, following: &Self) -> Self {
200        // debug!("composing {:?} and {:?}", self, following);
201        let mut map = self.map.clone();
202        for target in &mut map {
203            *target = target.and_then(|index| following.try_map(index));
204        }
205        Self::new(map, following.target_size())
206    }
207
208    pub fn clone_with_offset(&self, offset: usize) -> Self {
209        let mut map = self.map.clone();
210        for target in &mut map {
211            *target = target.and_then(|index| index.checked_add(offset));
212        }
213        Self::new(map, self.target_size() + offset)
214    }
215
216    /// Union two mapping, the result mapping `target_size` and source size will be the max size
217    /// of the two mappings.
218    ///
219    /// # Panics
220    ///
221    /// Will panic if a source appears in both to mapping
222    #[must_use]
223    pub fn union(&self, other: &Self) -> Self {
224        // debug!("union {:?} and {:?}", self, other);
225        let target_size = max(self.target_size(), other.target_size());
226        let source_size = max(self.source_size(), other.source_size());
227        let mut map = vec![None; source_size];
228        for (src, dst) in self.mapping_pairs() {
229            assert_eq!(map[src], None);
230            map[src] = Some(dst);
231        }
232        for (src, dst) in other.mapping_pairs() {
233            assert_eq!(map[src], None);
234            map[src] = Some(dst);
235        }
236        Self::new(map, target_size)
237    }
238
239    /// Inverse the mapping. If a target corresponds to more than one source, return `None`.
240    #[must_use]
241    pub fn inverse(&self) -> Option<Self> {
242        let mut map = vec![None; self.target_size()];
243        for (src, dst) in self.mapping_pairs() {
244            if map[dst].is_some() {
245                return None;
246            }
247            map[dst] = Some(src);
248        }
249        Some(Self::new(map, self.source_size()))
250    }
251
252    /// return iter of (src, dst) order by src
253    pub fn mapping_pairs(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
254        self.map
255            .iter()
256            .cloned()
257            .enumerate()
258            .filter_map(|(src, tar)| tar.map(|tar| (src, tar)))
259    }
260
261    /// Try mapping the source index to the target index.
262    pub fn try_map(&self, index: usize) -> Option<usize> {
263        *self.map.get(index)?
264    }
265
266    /// Try mapping all the source indices to the target indices. Returns `None` if any of the
267    /// indices is not mapped.
268    pub fn try_map_all(&self, indices: impl IntoIterator<Item = usize>) -> Option<Vec<usize>> {
269        indices.into_iter().map(|i| self.try_map(i)).collect()
270    }
271
272    /// # Panics
273    ///
274    /// Will panic if `index >= self.source_size()` or `index` is not mapped.
275    pub fn map(&self, index: usize) -> usize {
276        self.try_map(index).unwrap()
277    }
278
279    /// Returns the size of the target range. Target index is in the range `(0..target_size)`.
280    pub fn target_size(&self) -> usize {
281        self.target_size
282    }
283
284    /// Returns the size of the source range. Source index is in the range `(0..source_size)`.
285    pub fn source_size(&self) -> usize {
286        self.map.len()
287    }
288
289    pub fn is_empty(&self) -> bool {
290        self.target_size() == 0
291    }
292
293    pub fn is_injective(&self) -> bool {
294        let mut tar_exists = vec![false; self.target_size()];
295        for i in self.map.iter().flatten() {
296            if tar_exists[*i] {
297                return false;
298            }
299            tar_exists[*i] = true;
300        }
301        true
302    }
303}
304
305impl ColIndexMapping {
306    pub fn to_protobuf(&self) -> PbColIndexMapping {
307        PbColIndexMapping {
308            target_size: self.target_size as u64,
309            map: self
310                .map
311                .iter()
312                .map(|x| x.map_or(-1, |x| x as i64))
313                .collect(),
314        }
315    }
316
317    pub fn from_protobuf(prost: &PbColIndexMapping) -> ColIndexMapping {
318        ColIndexMapping {
319            target_size: prost.target_size as usize,
320            map: prost.map.iter().map(|&x| x.try_into().ok()).collect(),
321        }
322    }
323}
324
325impl ColIndexMapping {
326    /// Rewrite the dist-key indices and output indices in the given dispatch strategy. Returns
327    /// `None` if any of the indices is not mapped to the target.
328    pub fn rewrite_dispatch_strategy(
329        &self,
330        strategy: &DispatchStrategy,
331    ) -> Option<DispatchStrategy> {
332        let map = |index: &[u32]| -> Option<Vec<u32>> {
333            index
334                .iter()
335                .map(|i| self.try_map(*i as usize).map(|i| i as u32))
336                .collect()
337        };
338
339        Some(DispatchStrategy {
340            r#type: strategy.r#type,
341            dist_key_indices: map(&strategy.dist_key_indices)?,
342            output_indices: map(&strategy.output_indices)?,
343        })
344    }
345}
346
347impl Debug for ColIndexMapping {
348    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349        write!(
350            f,
351            "ColIndexMapping(source_size:{}, target_size:{}, mapping:{})",
352            self.source_size(),
353            self.target_size(),
354            self.mapping_pairs()
355                .map(|(src, dst)| format!("{}->{}", src, dst))
356                .join(",")
357        )
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shift_0() {
        let mapping = ColIndexMapping::with_shift_offset(3, 0);
        assert_eq!(mapping.map(0), 0);
        assert_eq!(mapping.map(1), 1);
        assert_eq!(mapping.map(2), 2);
        assert_eq!(mapping.try_map(3), None);
        assert_eq!(mapping.try_map(4), None);
    }

    #[test]
    fn test_shift_0_source() {
        let mapping = ColIndexMapping::with_shift_offset(0, 3);
        assert_eq!(mapping.target_size(), 3);
    }

    #[test]
    fn test_shift_negative() {
        let mapping = ColIndexMapping::with_shift_offset(6, -3);
        assert_eq!(mapping.target_size(), 3);
        for i in 0..3 {
            assert_eq!(mapping.try_map(i), None);
        }
        assert_eq!(mapping.map(3), 0);
        assert_eq!(mapping.map(4), 1);
        assert_eq!(mapping.map(5), 2);
        assert_eq!(mapping.try_map(6), None);
    }

    #[test]
    fn test_composite() {
        let add_mapping = ColIndexMapping::with_shift_offset(3, 3);
        let remaining_cols = vec![3, 5];
        let col_prune_mapping = ColIndexMapping::with_remaining_columns(&remaining_cols, 6);
        let composite = add_mapping.composite(&col_prune_mapping);
        assert_eq!(composite.map(0), 0); // 0+3 = 3, 3 -> 0
        assert_eq!(composite.try_map(1), None);
        assert_eq!(composite.map(2), 1); // 2+3 = 5, 5 -> 1
    }

    #[test]
    fn test_identity() {
        let mapping = ColIndexMapping::identity(10);
        assert!(mapping.is_identity());
        assert!(!ColIndexMapping::identity_or_none(4, 3).is_identity());
        assert!(!ColIndexMapping::empty(3, 3).is_identity());
    }

    #[test]
    fn test_with_removed_columns() {
        let mapping = ColIndexMapping::with_removed_columns(&[0, 2, 4], 5);
        assert_eq!(mapping.target_size(), 2);
        assert_eq!(mapping.map(1), 0);
        assert_eq!(mapping.map(3), 1);
        assert_eq!(mapping.try_map(0), None);
        assert_eq!(mapping.try_map(2), None);
        assert_eq!(mapping.try_map(4), None);
    }

    #[test]
    fn test_duplicate_columns() {
        // `with_remaining_columns` keeps the last occurrence, `with_included_columns` the first.
        let remaining = ColIndexMapping::with_remaining_columns(&[1, 1], 2);
        assert_eq!(remaining.map(1), 1);
        let included = ColIndexMapping::with_included_columns(&[1, 1], 2);
        assert_eq!(included.map(1), 0);
    }

    #[test]
    fn test_inverse() {
        let mapping = ColIndexMapping::with_remaining_columns(&[1, 3], 4);
        let inverse = mapping.inverse().unwrap();
        assert_eq!(inverse.source_size(), 2);
        assert_eq!(inverse.target_size(), 4);
        assert_eq!(inverse.map(0), 1);
        assert_eq!(inverse.map(1), 3);

        let not_injective = ColIndexMapping::new(vec![Some(0), Some(0)], 1);
        assert!(!not_injective.is_injective());
        assert!(not_injective.inverse().is_none());
    }

    #[test]
    fn test_union() {
        let left = ColIndexMapping::new(vec![Some(1), None], 2);
        let right = ColIndexMapping::new(vec![None, None, Some(0)], 1);
        let union = left.union(&right);
        assert_eq!(union.source_size(), 3);
        assert_eq!(union.target_size(), 2);
        assert_eq!(union.map(0), 1);
        assert_eq!(union.try_map(1), None);
        assert_eq!(union.map(2), 0);
    }

    #[test]
    #[should_panic]
    fn test_union_overlapping_sources() {
        let mapping = ColIndexMapping::identity(1);
        let _ = mapping.union(&mapping);
    }

    #[test]
    #[should_panic(expected = "target_max: 2, target_size: 2")]
    fn test_new_rejects_out_of_range_target() {
        let _ = ColIndexMapping::new(vec![Some(2)], 2);
    }

    #[test]
    fn test_debug_format() {
        let mapping = ColIndexMapping::with_remaining_columns(&[1, 3], 4);
        assert_eq!(
            format!("{:?}", mapping),
            "ColIndexMapping(source_size:4, target_size:2, mapping:1->0,3->1)"
        );
    }

    #[test]
    fn test_protobuf_round_trip() {
        let mapping = ColIndexMapping::new(vec![Some(2), None, Some(0)], 3);
        let restored = ColIndexMapping::from_protobuf(&mapping.to_protobuf());
        assert_eq!(restored, mapping);
    }
}