risingwave_frontend/optimizer/property/distribution.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15//!   "A -> B" represent A satisfies B
16//!                                 x
17//!  only as a required property    x  can used as both required
18//!                                 x  and provided property
19//!                                 x
20//!            ┌───┐                x┌──────┐
21//!            │Any◄─────────────────┤single│
22//!            └─▲─┘                x└──────┘
23//!              │                  x
24//!              │                  x
25//!              │                  x
26//!          ┌───┴────┐             x┌──────────┐
27//!          │AnyShard◄──────────────┤SomeShard │
28//!          └───▲────┘             x└──────────┘
29//!              │                  x
30//!          ┌───┴───────────┐      x┌──────────────┐ ┌──────────────┐
31//!          │ShardByKey(a,b)◄───┬───┤HashShard(a,b)│ │HashShard(b,a)│
32//!          └───▲──▲────────┘   │  x└──────────────┘ └┬─────────────┘
33//!              │  │            │  x                  │
34//!              │  │            └─────────────────────┘
35//!              │  │               x
36//!              │ ┌┴────────────┐  x┌────────────┐
37//!              │ │ShardByKey(a)◄───┤HashShard(a)│
38//!              │ └─────────────┘  x└────────────┘
39//!              │                  x
40//!             ┌┴────────────┐     x┌────────────┐
41//!             │ShardByKey(b)◄──────┤HashShard(b)│
42//!             └─────────────┘     x└────────────┘
43//!                                 x
44//!                                 x
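//!
//! A minimal illustration of the hierarchy above (a sketch, not compiled as a doctest):
//!
//! ```ignore
//! // `HashShard([0, 1])` satisfies `ShardByKey({0, 1})` (key order is irrelevant),
//! // but not the stricter `ShardByKey({0})`.
//! let dist = Distribution::HashShard(vec![0, 1]);
//! assert!(dist.satisfies(&RequiredDist::shard_by_key(2, &[0, 1])));
//! assert!(dist.satisfies(&RequiredDist::shard_by_key(2, &[1, 0])));
//! assert!(!dist.satisfies(&RequiredDist::shard_by_key(2, &[0])));
//! // Every sharded distribution satisfies `AnyShard`, and everything satisfies `Any`.
//! assert!(dist.satisfies(&RequiredDist::AnyShard));
//! assert!(dist.satisfies(&RequiredDist::Any));
//! ```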
use std::collections::HashMap;
use std::fmt;
use std::fmt::Debug;

use fixedbitset::FixedBitSet;
use generic::PhysicalPlanRef;
use itertools::Itertools;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::catalog::{FieldDisplay, Schema, TableId};
use risingwave_common::hash::WorkerSlotId;
use risingwave_pb::batch_plan::ExchangeInfo;
use risingwave_pb::batch_plan::exchange_info::{
    ConsistentHashInfo, Distribution as PbDistribution, DistributionMode, HashInfo,
};

use super::super::plan_node::*;
use crate::catalog::FragmentId;
use crate::catalog::catalog_service::CatalogReader;
use crate::error::Result;
use crate::optimizer::property::Order;

/// The distribution property provided by an operator.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Distribution {
    /// There is only one partition. All records are placed on it.
    ///
    /// Note: a singleton distribution will not be enforced automatically.
    /// It's set in `crate::stream_fragmenter::build_fragment`,
    /// by setting `requires_singleton` manually.
    Single,
    /// Records are sharded into partitions. This satisfies `AnyShard`, but there is no
    /// guarantee about the placement rule.
    SomeShard,
    /// Records are sharded into partitions based on the hash value of some columns, which means
    /// the records with the same hash values must be on the same partition.
    /// Each `usize` is the index of a column used as the distribution key.
    HashShard(Vec<usize>),
    /// A special kind of provided distribution which is almost the same as
    /// [`Distribution::HashShard`], but may have different vnode mapping.
    ///
    /// It exists because the upstream MV can be scaled independently. So we use
    /// `UpstreamHashShard` to **force an exchange to be inserted**.
    ///
    /// Alternatively, [`Distribution::SomeShard`] can also be used to insert an exchange, but
    /// `UpstreamHashShard` contains distribution keys, which might be useful in some cases, e.g.,
    /// two-phase Agg. It also satisfies [`RequiredDist::ShardByKey`].
    ///
    /// `TableId` is used to represent the data distribution (`vnode_mapping`) of this
    /// `UpstreamHashShard`. The scheduler can fetch `TableId`'s corresponding `vnode_mapping` to do
    /// shuffle.
    UpstreamHashShard(Vec<usize>, TableId),
    /// Records are available on all downstream shards.
    Broadcast,
}

/// The distribution property requirement.
#[derive(Debug, Clone, PartialEq)]
pub enum RequiredDist {
    /// Any distribution is acceptable.
    Any,
    /// Records are sharded across partitions, i.e., every record belongs to some partition.
    AnyShard,
    /// Records are sharded across partitions based on some keys (order-irrelevant:
    /// `ShardByKey({a,b})` is equivalent to `ShardByKey({b,a})`), which means records with the
    /// same keys must be on the same partition. Used as a required property only.
    ShardByKey(FixedBitSet),
    /// Must be exactly the given physical distribution.
    PhysicalDist(Distribution),
}

impl Distribution {
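    /// Convert this distribution into the protobuf `ExchangeInfo` used by batch exchange.
    /// For [`Distribution::UpstreamHashShard`], the upstream fragment's vnode mapping is
    /// resolved through the catalog and the worker node manager to build a consistent-hash
    /// distribution.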
    pub fn to_prost(
        &self,
        output_count: u32,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> Result<ExchangeInfo> {
        let exchange_info = ExchangeInfo {
            mode: match self {
                Distribution::Single => DistributionMode::Single,
                Distribution::HashShard(_) => DistributionMode::Hash,
                // TODO: add round robin DistributionMode
                Distribution::SomeShard => DistributionMode::Single,
                Distribution::Broadcast => DistributionMode::Broadcast,
                Distribution::UpstreamHashShard(_, _) => DistributionMode::ConsistentHash,
            } as i32,
            distribution: match self {
                Distribution::Single => None,
                Distribution::HashShard(key) => {
                    assert!(
                        !key.is_empty(),
                        "hash key should not be empty, use `Single` instead"
                    );
                    Some(PbDistribution::HashInfo(HashInfo {
                        output_count,
                        key: key.iter().map(|num| *num as u32).collect(),
                    }))
                }
                // TODO: add round robin distribution
                Distribution::SomeShard => None,
                Distribution::Broadcast => None,
                Distribution::UpstreamHashShard(key, table_id) => {
                    assert!(
                        !key.is_empty(),
                        "hash key should not be empty, use `Single` instead"
                    );

                    let vnode_mapping = worker_node_manager
                        .fragment_mapping(Self::get_fragment_id(catalog_reader, table_id)?)?;

                    let worker_slot_to_id_map: HashMap<WorkerSlotId, u32> = vnode_mapping
                        .iter_unique()
                        .enumerate()
                        .map(|(i, worker_slot_id)| (worker_slot_id, i as u32))
                        .collect();

                    Some(PbDistribution::ConsistentHashInfo(ConsistentHashInfo {
                        vmap: vnode_mapping
                            .iter()
                            .map(|id| worker_slot_to_id_map[&id])
                            .collect_vec(),
                        key: key.iter().map(|num| *num as u32).collect(),
                    }))
                }
            },
        };
        Ok(exchange_info)
    }

    /// Check whether this distribution satisfies the given required distribution.
    pub fn satisfies(&self, required: &RequiredDist) -> bool {
        match required {
            RequiredDist::Any => true,
            RequiredDist::AnyShard => {
                matches!(
                    self,
                    Distribution::SomeShard
                        | Distribution::HashShard(_)
                        | Distribution::UpstreamHashShard(_, _)
                        | Distribution::Broadcast
                )
            }
            RequiredDist::ShardByKey(required_key) => match self {
                Distribution::HashShard(hash_key)
                | Distribution::UpstreamHashShard(hash_key, _) => {
                    hash_key.iter().all(|idx| required_key.contains(*idx))
                }
                _ => false,
            },
            RequiredDist::PhysicalDist(other) => self == other,
        }
    }

    /// Get distribution column indices. Panics if the distribution is `SomeShard` or `Broadcast`.
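    ///
    /// For example (illustrative, not run as a doctest):
    ///
    /// ```ignore
    /// assert_eq!(Distribution::HashShard(vec![2, 0]).dist_column_indices(), &[2, 0]);
    /// assert_eq!(Distribution::Single.dist_column_indices(), &[] as &[usize]);
    /// ```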
    pub fn dist_column_indices(&self) -> &[usize] {
        self.dist_column_indices_opt()
            .unwrap_or_else(|| panic!("cannot obtain distribution columns for {self:?}"))
    }

    /// Get distribution column indices. Returns `None` if the distribution is `SomeShard` or `Broadcast`.
    pub fn dist_column_indices_opt(&self) -> Option<&[usize]> {
        match self {
            Distribution::Single => Some(&[]),
            Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => {
                Some(dists)
            }
            Distribution::SomeShard | Distribution::Broadcast => None,
        }
    }

    #[inline(always)]
    fn get_fragment_id(catalog_reader: &CatalogReader, table_id: &TableId) -> Result<FragmentId> {
        catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(Into::into)
    }
}

impl fmt::Display for Distribution {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[")?;
        match self {
            Self::Single => f.write_str("Single")?,
            Self::SomeShard => f.write_str("SomeShard")?,
            Self::Broadcast => f.write_str("Broadcast")?,
            Self::HashShard(vec) | Self::UpstreamHashShard(vec, _) => {
                for key in vec {
                    std::fmt::Debug::fmt(&key, f)?;
                }
            }
        }
        f.write_str("]")
    }
}

pub struct DistributionDisplay<'a> {
    pub distribution: &'a Distribution,
    pub input_schema: &'a Schema,
}

impl DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let that = self.distribution;
        match that {
            Distribution::Single => f.write_str("Single"),
            Distribution::SomeShard => f.write_str("SomeShard"),
            Distribution::Broadcast => f.write_str("Broadcast"),
            Distribution::HashShard(vec) | Distribution::UpstreamHashShard(vec, _) => {
                if let Distribution::HashShard(_) = that {
                    f.write_str("HashShard(")?;
                } else {
                    f.write_str("UpstreamHashShard(")?;
                }
                for (pos, key) in vec.iter().copied().with_position() {
                    std::fmt::Debug::fmt(
                        &FieldDisplay(self.input_schema.fields.get(key).unwrap()),
                        f,
                    )?;
                    match pos {
                        itertools::Position::First | itertools::Position::Middle => {
                            f.write_str(", ")?;
                        }
                        _ => {}
                    }
                }
                f.write_str(")")
            }
        }
    }
}

impl fmt::Debug for DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt(f)
    }
}

impl fmt::Display for DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt(f)
    }
}

impl RequiredDist {
    pub fn single() -> Self {
        Self::PhysicalDist(Distribution::Single)
    }

    pub fn shard_by_key(tot_col_num: usize, key: &[usize]) -> Self {
        let mut cols = FixedBitSet::with_capacity(tot_col_num);
        for i in key {
            cols.insert(*i);
        }
        assert!(!cols.is_clear());
        Self::ShardByKey(cols)
    }

    pub fn hash_shard(key: &[usize]) -> Self {
        assert!(!key.is_empty());
        Self::PhysicalDist(Distribution::HashShard(key.to_vec()))
    }

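    /// Enforce the required order first, then insert a `BatchExchange` if the plan's
    /// distribution does not already satisfy `self`.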
    pub fn batch_enforce_if_not_satisfies(
        &self,
        mut plan: BatchPlanRef,
        required_order: &Order,
    ) -> Result<BatchPlanRef> {
        plan = required_order.enforce_if_not_satisfies(plan)?;
        if !plan.distribution().satisfies(self) {
            Ok(self.batch_enforce(plan, required_order))
        } else {
            Ok(plan)
        }
    }

    pub fn streaming_enforce_if_not_satisfies(&self, plan: StreamPlanRef) -> Result<StreamPlanRef> {
        if !plan.distribution().satisfies(self) {
            Ok(self.stream_enforce(plan))
        } else {
            Ok(plan)
        }
    }

    pub fn no_shuffle(plan: StreamPlanRef) -> StreamPlanRef {
        StreamExchange::new_no_shuffle(plan).into()
    }

    /// Check whether this required distribution is at least as strict as `required`, i.e.,
    /// any distribution that satisfies `self` also satisfies `required`.
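    ///
    /// For example (illustrative, not run as a doctest):
    ///
    /// ```ignore
    /// // Sharding by a subset of columns is stricter than sharding by a superset.
    /// let by_a = RequiredDist::shard_by_key(2, &[0]);
    /// let by_ab = RequiredDist::shard_by_key(2, &[0, 1]);
    /// assert!(by_a.satisfies(&by_ab));
    /// assert!(!by_ab.satisfies(&by_a));
    /// ```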
    pub fn satisfies(&self, required: &RequiredDist) -> bool {
        match self {
            RequiredDist::Any => matches!(required, RequiredDist::Any),
            RequiredDist::AnyShard => {
                matches!(required, RequiredDist::Any | RequiredDist::AnyShard)
            }
            RequiredDist::ShardByKey(key) => match required {
                RequiredDist::Any | RequiredDist::AnyShard => true,
                RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
                _ => false,
            },
            RequiredDist::PhysicalDist(dist) => dist.satisfies(required),
        }
    }

    pub fn batch_enforce(&self, plan: BatchPlanRef, required_order: &Order) -> BatchPlanRef {
        let dist = self.to_dist();
        BatchExchange::new(plan, required_order.clone(), dist).into()
    }

    pub fn stream_enforce(&self, plan: StreamPlanRef) -> StreamPlanRef {
        let dist = self.to_dist();
        StreamExchange::new(plan, dist).into()
    }

    fn to_dist(&self) -> Distribution {
        match self {
            // All distributions satisfy `Any`, and this function is only called through
            // `enforce_if_not_satisfies`, so `Any` never needs to be enforced.
            RequiredDist::Any => unreachable!(),
            // TODO: add a round-robin distribution type
            RequiredDist::AnyShard => todo!(),
            RequiredDist::ShardByKey(required_keys) => {
                Distribution::HashShard(required_keys.ones().collect())
            }
            RequiredDist::PhysicalDist(dist) => dist.clone(),
        }
    }
}

impl StreamPlanRef {
    /// Eliminate `SomeShard` distribution by using the stream key as the distribution key to
    /// enforce the current plan to have a known distribution key.
    pub fn enforce_concrete_distribution(self) -> Self {
        match self.distribution() {
            Distribution::SomeShard => {
                RequiredDist::shard_by_key(self.schema().len(), self.expect_stream_key())
                    .stream_enforce(self)
            }
            _ => self,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{Distribution, RequiredDist};

    #[test]
    fn hash_shard_satisfy() {
        let d1 = Distribution::HashShard(vec![0, 1]);
        let d2 = Distribution::HashShard(vec![1, 0]);
        let d3 = Distribution::HashShard(vec![0]);
        let d4 = Distribution::HashShard(vec![1]);

        let r1 = RequiredDist::shard_by_key(2, &[0, 1]);
        let r3 = RequiredDist::shard_by_key(2, &[0]);
        let r4 = RequiredDist::shard_by_key(2, &[1]);
        assert!(d1.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(d2.satisfies(&RequiredDist::PhysicalDist(d2.clone())));
        assert!(d3.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(d4.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d3.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d4.satisfies(&RequiredDist::PhysicalDist(d1.clone())));

        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d4.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        assert!(d1.satisfies(&r1));
        assert!(d2.satisfies(&r1));
        assert!(d3.satisfies(&r1));
        assert!(d4.satisfies(&r1));

        assert!(!d1.satisfies(&r3));
        assert!(!d2.satisfies(&r3));
        assert!(d3.satisfies(&r3));
        assert!(!d4.satisfies(&r3));

        assert!(!d1.satisfies(&r4));
        assert!(!d2.satisfies(&r4));
        assert!(!d3.satisfies(&r4));
        assert!(d4.satisfies(&r4));

        assert!(r3.satisfies(&r1));
        assert!(r4.satisfies(&r1));
        assert!(!r1.satisfies(&r3));
        assert!(!r1.satisfies(&r4));
        assert!(!r3.satisfies(&r4));
        assert!(!r4.satisfies(&r3));
    }
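
    /// A small supplementary check (illustrative sketch): how the non-hash distributions
    /// interact with the required properties, per the hierarchy in the module docs.
    #[test]
    fn non_hash_satisfy() {
        let single = Distribution::Single;
        let some_shard = Distribution::SomeShard;
        let broadcast = Distribution::Broadcast;

        // Everything satisfies `Any`.
        assert!(single.satisfies(&RequiredDist::Any));
        assert!(some_shard.satisfies(&RequiredDist::Any));
        assert!(broadcast.satisfies(&RequiredDist::Any));

        // Only sharded or broadcast distributions satisfy `AnyShard`.
        assert!(!single.satisfies(&RequiredDist::AnyShard));
        assert!(some_shard.satisfies(&RequiredDist::AnyShard));
        assert!(broadcast.satisfies(&RequiredDist::AnyShard));

        // `ShardByKey` needs a known hash key, which `SomeShard` and `Broadcast` cannot provide.
        let by_key = RequiredDist::shard_by_key(2, &[0]);
        assert!(!some_shard.satisfies(&by_key));
        assert!(!broadcast.satisfies(&by_key));
    }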
}