risingwave_frontend/optimizer/property/distribution.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//!   "A -> B" represents "A satisfies B"
//!                                 x
//!  only as a required property    x  can be used as both required
//!                                 x  and provided property
//!                                 x
//!            ┌───┐                x┌──────┐
//!            │Any◄─────────────────┤single│
//!            └─▲─┘                x└──────┘
//!              │                  x
//!              │                  x
//!              │                  x
//!          ┌───┴────┐             x┌──────────┐
//!          │AnyShard◄──────────────┤SomeShard │
//!          └───▲────┘             x└──────────┘
//!              │                  x
//!          ┌───┴───────────┐      x┌──────────────┐ ┌──────────────┐
//!          │ShardByKey(a,b)◄───┬───┤HashShard(a,b)│ │HashShard(b,a)│
//!          └───▲──▲────────┘   │  x└──────────────┘ └┬─────────────┘
//!              │  │            │  x                  │
//!              │  │            └─────────────────────┘
//!              │  │               x
//!              │ ┌┴────────────┐  x┌────────────┐
//!              │ │ShardByKey(a)◄───┤HashShard(a)│
//!              │ └─────────────┘  x└────────────┘
//!              │                  x
//!             ┌┴────────────┐     x┌────────────┐
//!             │ShardByKey(b)◄──────┤HashShard(b)│
//!             └─────────────┘     x└────────────┘
//!                                 x
//!                                 x
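//!
//! A minimal sketch of the hierarchy above in code (illustrative only; the full
//! satisfaction matrix is exercised by `hash_shard_satisfy` in the tests at the
//! bottom of this file):
//!
//! ```ignore
//! // `HashShard(a, b)` is a provided property; `ShardByKey({a, b})` is a requirement.
//! let provided = Distribution::HashShard(vec![0, 1]);
//! assert!(provided.satisfies(&RequiredDist::shard_by_key(2, &[0, 1])));
//! // Every sharded distribution also satisfies the weaker `AnyShard` requirement.
//! assert!(provided.satisfies(&RequiredDist::AnyShard));
//! ```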
use std::collections::HashMap;
use std::fmt;
use std::fmt::Debug;

use fixedbitset::FixedBitSet;
use generic::PhysicalPlanRef;
use itertools::Itertools;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::catalog::{FieldDisplay, Schema, TableId};
use risingwave_common::hash::WorkerSlotId;
use risingwave_pb::batch_plan::ExchangeInfo;
use risingwave_pb::batch_plan::exchange_info::{
    ConsistentHashInfo, Distribution as PbDistribution, DistributionMode, HashInfo,
};

use super::super::plan_node::*;
use crate::catalog::FragmentId;
use crate::catalog::catalog_service::CatalogReader;
use crate::error::Result;
use crate::optimizer::property::Order;

/// The distribution property provided by an operator.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Distribution {
    /// There is only one partition. All records are placed on it.
    ///
    /// Note: singleton will not be enforced automatically.
    /// It's set in `crate::stream_fragmenter::build_fragment`,
    /// by setting `requires_singleton` manually.
    Single,
    /// Records are sharded into partitions. This satisfies `AnyShard`, but gives no guarantee
    /// about the placement of the records.
    SomeShard,
    /// Records are sharded into partitions based on the hash value of some columns, which means
    /// records with the same hash values must be on the same partition.
    /// Each `usize` is the index of a column used as the distribution key.
    HashShard(Vec<usize>),
    /// A special kind of provided distribution which is almost the same as
    /// [`Distribution::HashShard`], but may have a different vnode mapping.
    ///
    /// It exists because the upstream MV can be scaled independently. So we use
    /// `UpstreamHashShard` to **force an exchange to be inserted**.
    ///
    /// Alternatively, [`Distribution::SomeShard`] can also be used to insert an exchange, but
    /// `UpstreamHashShard` contains distribution keys, which might be useful in some cases, e.g.,
    /// two-phase Agg. It also satisfies [`RequiredDist::ShardByKey`].
    ///
    /// `TableId` is used to represent the data distribution (`vnode_mapping`) of this
    /// `UpstreamHashShard`. The scheduler can fetch the `vnode_mapping` corresponding to the
    /// `TableId` to do the shuffle.
    UpstreamHashShard(Vec<usize>, TableId),
    /// Records are available on all downstream shards.
    Broadcast,
}

/// The distribution property requirement.
#[derive(Debug, Clone, PartialEq)]
pub enum RequiredDist {
    /// Any distribution is accepted.
    Any,
    /// Records are sharded into partitions, i.e., every record belongs to some partition.
    AnyShard,
    /// Records are sharded into partitions based on a set of keys (order-irrelevant:
    /// `ShardByKey({a,b})` is equivalent to `ShardByKey({b,a})`), which means records with the
    /// same keys must be on the same partition. Used as a required property only.
    ShardByKey(FixedBitSet),
    /// Must be exactly the given physical distribution.
    PhysicalDist(Distribution),
}

impl Distribution {
    pub fn to_prost(
        &self,
        output_count: u32,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> Result<ExchangeInfo> {
        let exchange_info = ExchangeInfo {
            mode: match self {
                Distribution::Single => DistributionMode::Single,
                Distribution::HashShard(_) => DistributionMode::Hash,
                // TODO: add round robin DistributionMode
                Distribution::SomeShard => DistributionMode::Single,
                Distribution::Broadcast => DistributionMode::Broadcast,
                Distribution::UpstreamHashShard(_, _) => DistributionMode::ConsistentHash,
            } as i32,
            distribution: match self {
                Distribution::Single => None,
                Distribution::HashShard(key) => {
                    assert!(
                        !key.is_empty(),
                        "hash key should not be empty, use `Single` instead"
                    );
                    Some(PbDistribution::HashInfo(HashInfo {
                        output_count,
                        key: key.iter().map(|num| *num as u32).collect(),
                    }))
                }
                // TODO: add round robin distribution
                Distribution::SomeShard => None,
                Distribution::Broadcast => None,
                Distribution::UpstreamHashShard(key, table_id) => {
                    assert!(
                        !key.is_empty(),
                        "hash key should not be empty, use `Single` instead"
                    );

                    let vnode_mapping = worker_node_manager
                        .fragment_mapping(Self::get_fragment_id(catalog_reader, table_id)?)?;

                    let worker_slot_to_id_map: HashMap<WorkerSlotId, u32> = vnode_mapping
                        .iter_unique()
                        .enumerate()
                        .map(|(i, worker_slot_id)| (worker_slot_id, i as u32))
                        .collect();

                    Some(PbDistribution::ConsistentHashInfo(ConsistentHashInfo {
                        vmap: vnode_mapping
                            .iter()
                            .map(|id| worker_slot_to_id_map[&id])
                            .collect_vec(),
                        key: key.iter().map(|num| *num as u32).collect(),
                    }))
                }
            },
        };
        Ok(exchange_info)
    }

    /// Check whether this provided distribution satisfies the required distribution.
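    ///
    /// A small sketch (illustrative only; see `hash_shard_satisfy` in the tests for a fuller
    /// matrix):
    ///
    /// ```ignore
    /// // Hash distribution keys must form a subset of the required key set.
    /// let d = Distribution::HashShard(vec![0]);
    /// assert!(d.satisfies(&RequiredDist::shard_by_key(2, &[0, 1])));
    /// assert!(!d.satisfies(&RequiredDist::shard_by_key(2, &[1])));
    /// ```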
    pub fn satisfies(&self, required: &RequiredDist) -> bool {
        match required {
            RequiredDist::Any => true,
            RequiredDist::AnyShard => {
                matches!(
                    self,
                    Distribution::SomeShard
                        | Distribution::HashShard(_)
                        | Distribution::UpstreamHashShard(_, _)
                        | Distribution::Broadcast
                )
            }
            RequiredDist::ShardByKey(required_key) => match self {
                Distribution::HashShard(hash_key)
                | Distribution::UpstreamHashShard(hash_key, _) => {
                    hash_key.iter().all(|idx| required_key.contains(*idx))
                }
                _ => false,
            },
            RequiredDist::PhysicalDist(other) => self == other,
        }
    }

    /// Get distribution column indices.
    ///
    /// Panics if the distribution is not `HashShard`, `UpstreamHashShard` or `Single`.
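    ///
    /// A small sketch (illustrative only):
    ///
    /// ```ignore
    /// let dist = Distribution::HashShard(vec![1, 2]);
    /// assert_eq!(dist.dist_column_indices().to_vec(), vec![1, 2]);
    /// assert!(Distribution::Single.dist_column_indices().is_empty());
    /// ```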
    pub fn dist_column_indices(&self) -> &[usize] {
        match self {
            Distribution::Single => &[],
            Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => dists,
            Distribution::SomeShard | Distribution::Broadcast => {
                panic!("cannot obtain distribution columns for {self:?}")
            }
        }
    }

    #[inline(always)]
    fn get_fragment_id(catalog_reader: &CatalogReader, table_id: &TableId) -> Result<FragmentId> {
        catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(Into::into)
    }
}

impl fmt::Display for Distribution {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[")?;
        match self {
            Self::Single => f.write_str("Single")?,
            Self::SomeShard => f.write_str("SomeShard")?,
            Self::Broadcast => f.write_str("Broadcast")?,
            Self::HashShard(vec) | Self::UpstreamHashShard(vec, _) => {
                for key in vec {
                    std::fmt::Debug::fmt(&key, f)?;
                }
            }
        }
        f.write_str("]")
    }
}

pub struct DistributionDisplay<'a> {
    pub distribution: &'a Distribution,
    pub input_schema: &'a Schema,
}

impl DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let that = self.distribution;
        match that {
            Distribution::Single => f.write_str("Single"),
            Distribution::SomeShard => f.write_str("SomeShard"),
            Distribution::Broadcast => f.write_str("Broadcast"),
            Distribution::HashShard(vec) | Distribution::UpstreamHashShard(vec, _) => {
                if let Distribution::HashShard(_) = that {
                    f.write_str("HashShard(")?;
                } else {
                    f.write_str("UpstreamHashShard(")?;
                }
                for (pos, key) in vec.iter().copied().with_position() {
                    std::fmt::Debug::fmt(
                        &FieldDisplay(self.input_schema.fields.get(key).unwrap()),
                        f,
                    )?;
                    match pos {
                        itertools::Position::First | itertools::Position::Middle => {
                            f.write_str(", ")?;
                        }
                        _ => {}
                    }
                }
                f.write_str(")")
            }
        }
    }
}

impl fmt::Debug for DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt(f)
    }
}

impl fmt::Display for DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt(f)
    }
}

impl RequiredDist {
    pub fn single() -> Self {
        Self::PhysicalDist(Distribution::Single)
    }

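    /// Build a `ShardByKey` requirement from column indices of a schema with `tot_col_num`
    /// columns. The key set is order-irrelevant.
    ///
    /// A small sketch (illustrative only):
    ///
    /// ```ignore
    /// // Both calls yield the same requirement: shard by the key set {0, 1}.
    /// let r1 = RequiredDist::shard_by_key(3, &[0, 1]);
    /// let r2 = RequiredDist::shard_by_key(3, &[1, 0]);
    /// assert!(r1.satisfies(&r2) && r2.satisfies(&r1));
    /// ```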
    pub fn shard_by_key(tot_col_num: usize, key: &[usize]) -> Self {
        let mut cols = FixedBitSet::with_capacity(tot_col_num);
        for i in key {
            cols.insert(*i);
        }
        assert!(!cols.is_clear());
        Self::ShardByKey(cols)
    }

    pub fn hash_shard(key: &[usize]) -> Self {
        assert!(!key.is_empty());
        Self::PhysicalDist(Distribution::HashShard(key.to_vec()))
    }

    pub fn batch_enforce_if_not_satisfies(
        &self,
        mut plan: BatchPlanRef,
        required_order: &Order,
    ) -> Result<BatchPlanRef> {
        plan = required_order.enforce_if_not_satisfies(plan)?;
        if !plan.distribution().satisfies(self) {
            Ok(self.batch_enforce(plan, required_order))
        } else {
            Ok(plan)
        }
    }

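    /// Enforce this required distribution on a streaming plan, inserting a `StreamExchange`
    /// only if the plan's current distribution does not already satisfy it.
    ///
    /// A usage sketch (illustrative only; `input` stands for a hypothetical `StreamPlanRef`
    /// produced elsewhere in the optimizer):
    ///
    /// ```ignore
    /// // Require hash distribution on column 0; this is a no-op if `input` already satisfies it.
    /// let plan = RequiredDist::hash_shard(&[0]).streaming_enforce_if_not_satisfies(input)?;
    /// ```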
    pub fn streaming_enforce_if_not_satisfies(&self, plan: StreamPlanRef) -> Result<StreamPlanRef> {
        if !plan.distribution().satisfies(self) {
            Ok(self.stream_enforce(plan))
        } else {
            Ok(plan)
        }
    }

    pub fn no_shuffle(plan: StreamPlanRef) -> StreamPlanRef {
        StreamExchange::new_no_shuffle(plan).into()
    }

    /// Check whether this required distribution implies the other required distribution, i.e.,
    /// whether every distribution satisfying `self` also satisfies `required`.
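    ///
    /// A small sketch (illustrative only):
    ///
    /// ```ignore
    /// // Sharding by {0} is stricter than sharding by {0, 1}, so the former implies the latter.
    /// let by_a = RequiredDist::shard_by_key(2, &[0]);
    /// let by_ab = RequiredDist::shard_by_key(2, &[0, 1]);
    /// assert!(by_a.satisfies(&by_ab));
    /// assert!(!by_ab.satisfies(&by_a));
    /// ```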
    pub fn satisfies(&self, required: &RequiredDist) -> bool {
        match self {
            RequiredDist::Any => matches!(required, RequiredDist::Any),
            RequiredDist::AnyShard => {
                matches!(required, RequiredDist::Any | RequiredDist::AnyShard)
            }
            RequiredDist::ShardByKey(key) => match required {
                RequiredDist::Any | RequiredDist::AnyShard => true,
                RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
                _ => false,
            },
            RequiredDist::PhysicalDist(dist) => dist.satisfies(required),
        }
    }

    pub fn batch_enforce(&self, plan: BatchPlanRef, required_order: &Order) -> BatchPlanRef {
        let dist = self.to_dist();
        BatchExchange::new(plan, required_order.clone(), dist).into()
    }

    pub fn stream_enforce(&self, plan: StreamPlanRef) -> StreamPlanRef {
        let dist = self.to_dist();
        StreamExchange::new(plan, dist).into()
    }

    fn to_dist(&self) -> Distribution {
        match self {
            // All distributions satisfy `Any`, and this function is only reached via the
            // `enforce_if_not_satisfies` paths, so `Any` never needs to be enforced here.
            RequiredDist::Any => unreachable!(),
            // TODO: add round robin distributed type
            RequiredDist::AnyShard => todo!(),
            RequiredDist::ShardByKey(required_keys) => {
                Distribution::HashShard(required_keys.ones().collect())
            }
            RequiredDist::PhysicalDist(dist) => dist.clone(),
        }
    }
}

impl StreamPlanRef {
    /// Eliminate `SomeShard` distribution by hash-sharding on the stream key, so that the plan
    /// is enforced to have a concrete, known distribution key.
    pub fn enforce_concrete_distribution(self) -> Self {
        match self.distribution() {
            Distribution::SomeShard => {
                RequiredDist::shard_by_key(self.schema().len(), self.expect_stream_key())
                    .stream_enforce(self)
            }
            _ => self,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{Distribution, RequiredDist};

    #[test]
    fn hash_shard_satisfy() {
        let d1 = Distribution::HashShard(vec![0, 1]);
        let d2 = Distribution::HashShard(vec![1, 0]);
        let d3 = Distribution::HashShard(vec![0]);
        let d4 = Distribution::HashShard(vec![1]);

        let r1 = RequiredDist::shard_by_key(2, &[0, 1]);
        let r3 = RequiredDist::shard_by_key(2, &[0]);
        let r4 = RequiredDist::shard_by_key(2, &[1]);
        assert!(d1.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(d2.satisfies(&RequiredDist::PhysicalDist(d2.clone())));
        assert!(d3.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(d4.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d3.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d4.satisfies(&RequiredDist::PhysicalDist(d1.clone())));

        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d4.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        assert!(d1.satisfies(&r1));
        assert!(d2.satisfies(&r1));
        assert!(d3.satisfies(&r1));
        assert!(d4.satisfies(&r1));

        assert!(!d1.satisfies(&r3));
        assert!(!d2.satisfies(&r3));
        assert!(d3.satisfies(&r3));
        assert!(!d4.satisfies(&r3));

        assert!(!d1.satisfies(&r4));
        assert!(!d2.satisfies(&r4));
        assert!(!d3.satisfies(&r4));
        assert!(d4.satisfies(&r4));

        assert!(r3.satisfies(&r1));
        assert!(r4.satisfies(&r1));
        assert!(!r1.satisfies(&r3));
        assert!(!r1.satisfies(&r4));
        assert!(!r3.satisfies(&r4));
        assert!(!r4.satisfies(&r3));
    }
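
    #[test]
    fn any_shard_satisfy() {
        // An additional sketch (not exhaustive): every sharded distribution satisfies the
        // `AnyShard` requirement, while `Single` does not; everything satisfies `Any`.
        let single = Distribution::Single;
        let some_shard = Distribution::SomeShard;
        let hash = Distribution::HashShard(vec![0]);
        let broadcast = Distribution::Broadcast;

        assert!(some_shard.satisfies(&RequiredDist::AnyShard));
        assert!(hash.satisfies(&RequiredDist::AnyShard));
        assert!(broadcast.satisfies(&RequiredDist::AnyShard));
        assert!(!single.satisfies(&RequiredDist::AnyShard));

        assert!(single.satisfies(&RequiredDist::Any));
        assert!(some_shard.satisfies(&RequiredDist::Any));
        assert!(hash.satisfies(&RequiredDist::Any));
        assert!(broadcast.satisfies(&RequiredDist::Any));
    }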
}