risingwave_frontend/optimizer/property/
distribution.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//!   "A -> B" represent A satisfies B
16//!                                 x
//!  only as a required property    x  can be used as both required
18//!                                 x  and provided property
19//!                                 x
20//!            ┌───┐                x┌──────┐
21//!            │Any◄─────────────────┤single│
22//!            └─▲─┘                x└──────┘
23//!              │                  x
24//!              │                  x
25//!              │                  x
26//!          ┌───┴────┐             x┌──────────┐
27//!          │AnyShard◄──────────────┤SomeShard │
28//!          └───▲────┘             x└──────────┘
29//!              │                  x
30//!          ┌───┴───────────┐      x┌──────────────┐ ┌──────────────┐
31//!          │ShardByKey(a,b)◄───┬───┤HashShard(a,b)│ │HashShard(b,a)│
32//!          └───▲──▲────────┘   │  x└──────────────┘ └┬─────────────┘
33//!              │  │            │  x                  │
34//!              │  │            └─────────────────────┘
35//!              │  │               x
36//!              │ ┌┴────────────┐  x┌────────────┐
37//!              │ │ShardByKey(a)◄───┤HashShard(a)│
38//!              │ └─────────────┘  x└────────────┘
39//!              │                  x
40//!             ┌┴────────────┐     x┌────────────┐
41//!             │ShardByKey(b)◄──────┤HashShard(b)│
42//!             └─────────────┘     x└────────────┘
43//!                                 x
44//!                                 x
45use std::collections::HashMap;
46use std::fmt;
47use std::fmt::Debug;
48
49use fixedbitset::FixedBitSet;
50use generic::PhysicalPlanRef;
51use itertools::Itertools;
52use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
53use risingwave_common::catalog::{FieldDisplay, Schema, TableId};
54use risingwave_common::hash::WorkerSlotId;
55use risingwave_pb::batch_plan::ExchangeInfo;
56use risingwave_pb::batch_plan::exchange_info::{
57    ConsistentHashInfo, Distribution as PbDistribution, DistributionMode, HashInfo,
58};
59
60use super::super::plan_node::*;
61use crate::catalog::FragmentId;
62use crate::catalog::catalog_service::CatalogReader;
63use crate::error::Result;
64use crate::optimizer::property::Order;
65
66/// the distribution property provided by a operator.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Distribution {
    /// There is only one partition. All records are placed on it.
    ///
    /// Note: singleton will not be enforced automatically.
    /// It's set in `crate::stream_fragmenter::build_fragment`,
    /// by setting `requires_singleton` manually.
    Single,
    /// Records are sharded into partitions, and satisfy the `AnyShard` but without any guarantee
    /// about their placement rules.
    SomeShard,
    /// Records are sharded into partitions based on the hash value of some columns, which means
    /// the records with the same hash values must be on the same partition.
    /// `usize` is the index of column used as the distribution key.
    ///
    /// Note: key order is significant — `HashShard([0, 1])` and `HashShard([1, 0])` are
    /// distinct distributions (see the module-level diagram and `satisfies` tests).
    HashShard(Vec<usize>),
    /// A special kind of provided distribution which is almost the same as
    /// [`Distribution::HashShard`], but may have different vnode mapping.
    ///
    /// It exists because the upstream MV can be scaled independently. So we use
    /// `UpstreamHashShard` to **force an exchange to be inserted**.
    ///
    /// Alternatively, [`Distribution::SomeShard`] can also be used to insert an exchange, but
    /// `UpstreamHashShard` contains distribution keys, which might be useful in some cases, e.g.,
    /// two-phase Agg. It also satisfies [`RequiredDist::ShardByKey`].
    ///
    /// `TableId` is used to represent the data distribution(`vnode_mapping`) of this
    /// `UpstreamHashShard`. The scheduler can fetch `TableId`'s corresponding `vnode_mapping` to do
    /// shuffle.
    UpstreamHashShard(Vec<usize>, TableId),
    /// Records are available on all downstream shards.
    Broadcast,
}
99
100/// the distribution property requirement.
#[derive(Debug, Clone, PartialEq)]
pub enum RequiredDist {
    /// with any distribution
    Any,
    /// records are sharded on partitions, which means every record should belong to a partition
    AnyShard,
    /// records are sharded on partitions based on some keys (order-irrelevance, ShardByKey({a,b})
    /// is equivalent with ShardByKey({b,a})), which means the records with same keys must be on
    /// the same partition, as required property only. Any distribution sharded by a subset of this
    /// key set satisfies the requirement.
    ShardByKey(FixedBitSet),
    /// records are sharded on partitions based on an exact set of keys (order-irrelevance).
    /// Only distribution sharded by the same key set satisfies this requirement.
    ShardByExactKey(FixedBitSet),
    /// must be the same with the physical distribution
    PhysicalDist(Distribution),
}
118
impl Distribution {
    /// Converts this distribution into the protobuf [`ExchangeInfo`] that tells the batch
    /// executor how to shuffle records to `output_count` downstream partitions.
    ///
    /// `catalog_reader` and `worker_node_manager` are only consulted for
    /// [`Distribution::UpstreamHashShard`], to resolve the upstream table's vnode mapping.
    ///
    /// # Panics
    /// Panics if a `HashShard`/`UpstreamHashShard` has an empty key (use `Single` instead).
    pub fn to_prost(
        &self,
        output_count: u32,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> Result<ExchangeInfo> {
        let exchange_info = ExchangeInfo {
            mode: match self {
                Distribution::Single => DistributionMode::Single,
                Distribution::HashShard(_) => DistributionMode::Hash,
                // TODO: add round robin DistributionMode
                // Until then, `SomeShard` degrades to `Single` mode (no distribution payload
                // below either).
                Distribution::SomeShard => DistributionMode::Single,
                Distribution::Broadcast => DistributionMode::Broadcast,
                Distribution::UpstreamHashShard(_, _) => DistributionMode::ConsistentHash,
            } as i32,
            distribution: match self {
                Distribution::Single => None,
                Distribution::HashShard(key) => {
                    assert!(
                        !key.is_empty(),
                        "hash key should not be empty, use `Single` instead"
                    );
                    Some(PbDistribution::HashInfo(HashInfo {
                        output_count,
                        key: key.iter().map(|num| *num as u32).collect(),
                    }))
                }
                // TODO: add round robin distribution
                Distribution::SomeShard => None,
                Distribution::Broadcast => None,
                Distribution::UpstreamHashShard(key, table_id) => {
                    assert!(
                        !key.is_empty(),
                        "hash key should not be empty, use `Single` instead"
                    );

                    // Resolve the upstream table's fragment, then its vnode -> worker-slot
                    // mapping, so the exchange can route rows consistently with the upstream.
                    let vnode_mapping = worker_node_manager
                        .fragment_mapping(Self::get_fragment_id(catalog_reader, *table_id)?)?;

                    // Assign each distinct worker slot a dense index in first-appearance
                    // order of `iter_unique`; `vmap` below is expressed in these indices.
                    let worker_slot_to_id_map: HashMap<WorkerSlotId, u32> = vnode_mapping
                        .iter_unique()
                        .enumerate()
                        .map(|(i, worker_slot_id)| (worker_slot_id, i as u32))
                        .collect();

                    Some(PbDistribution::ConsistentHashInfo(ConsistentHashInfo {
                        vmap: vnode_mapping
                            .iter()
                            .map(|id| worker_slot_to_id_map[&id])
                            .collect_vec(),
                        key: key.iter().map(|num| *num as u32).collect(),
                    }))
                }
            },
        };
        Ok(exchange_info)
    }

    /// check if the distribution satisfies other required distribution
    pub fn satisfies(&self, required: &RequiredDist) -> bool {
        match required {
            RequiredDist::Any => true,
            // Everything except `Single` counts as sharded.
            RequiredDist::AnyShard => {
                matches!(
                    self,
                    Distribution::SomeShard
                        | Distribution::HashShard(_)
                        | Distribution::UpstreamHashShard(_, _)
                        | Distribution::Broadcast
                )
            }
            // Sharding by any subset of the required keys satisfies `ShardByKey`
            // (see the module-level diagram: HashShard(a) -> ShardByKey(a,b)).
            RequiredDist::ShardByKey(required_key) => match self {
                Distribution::HashShard(hash_key)
                | Distribution::UpstreamHashShard(hash_key, _) => {
                    hash_key.iter().all(|idx| required_key.contains(*idx))
                }
                _ => false,
            },
            // `ShardByExactKey` additionally requires the key sets to have the same
            // cardinality, i.e. to be equal as sets (key order is still irrelevant).
            RequiredDist::ShardByExactKey(required_key) => match self {
                Distribution::HashShard(hash_key)
                | Distribution::UpstreamHashShard(hash_key, _) => {
                    hash_key.len() == required_key.count_ones(..)
                        && hash_key.iter().all(|idx| required_key.contains(*idx))
                }
                _ => false,
            },
            // Exact structural equality, including key order.
            RequiredDist::PhysicalDist(other) => self == other,
        }
    }

    /// Get distribution column indices. Panics if the distribution is `SomeShard` or `Broadcast`.
    pub fn dist_column_indices(&self) -> &[usize] {
        self.dist_column_indices_opt()
            .unwrap_or_else(|| panic!("cannot obtain distribution columns for {self:?}"))
    }

    /// Get distribution column indices. Returns `None` if the distribution is `SomeShard` or `Broadcast`.
    pub fn dist_column_indices_opt(&self) -> Option<&[usize]> {
        match self {
            // `Single` has a well-defined (empty) key set, unlike `SomeShard`/`Broadcast`.
            Distribution::Single => Some(&[]),
            Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => {
                Some(dists)
            }
            Distribution::SomeShard | Distribution::Broadcast => None,
        }
    }

    /// Looks up the fragment id of `table_id` via the catalog, used to fetch its vnode mapping.
    #[inline(always)]
    fn get_fragment_id(catalog_reader: &CatalogReader, table_id: TableId) -> Result<FragmentId> {
        catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(Into::into)
    }
}
236
237impl fmt::Display for Distribution {
238    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239        f.write_str("[")?;
240        match self {
241            Self::Single => f.write_str("Single")?,
242            Self::SomeShard => f.write_str("SomeShard")?,
243            Self::Broadcast => f.write_str("Broadcast")?,
244            Self::HashShard(vec) | Self::UpstreamHashShard(vec, _) => {
245                for key in vec {
246                    std::fmt::Debug::fmt(&key, f)?;
247                }
248            }
249        }
250        f.write_str("]")
251    }
252}
253
/// Wrapper that renders a [`Distribution`] for plan explanation, resolving key column
/// indices to field names via the input schema.
pub struct DistributionDisplay<'a> {
    /// The distribution to render.
    pub distribution: &'a Distribution,
    /// Schema of the operator's input; key indices are looked up in its `fields`.
    pub input_schema: &'a Schema,
}
258
259impl DistributionDisplay<'_> {
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        let that = self.distribution;
262        match that {
263            Distribution::Single => f.write_str("Single"),
264            Distribution::SomeShard => f.write_str("SomeShard"),
265            Distribution::Broadcast => f.write_str("Broadcast"),
266            Distribution::HashShard(vec) | Distribution::UpstreamHashShard(vec, _) => {
267                if let Distribution::HashShard(_) = that {
268                    f.write_str("HashShard(")?;
269                } else {
270                    f.write_str("UpstreamHashShard(")?;
271                }
272                for (pos, key) in vec.iter().copied().with_position() {
273                    std::fmt::Debug::fmt(
274                        &FieldDisplay(self.input_schema.fields.get(key).unwrap()),
275                        f,
276                    )?;
277                    match pos {
278                        itertools::Position::First | itertools::Position::Middle => {
279                            f.write_str(", ")?;
280                        }
281                        _ => {}
282                    }
283                }
284                f.write_str(")")
285            }
286        }
287    }
288}
289
impl fmt::Debug for DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegates to the inherent `fmt`; inherent methods take precedence over trait
        // methods in resolution, so this does not recurse.
        self.fmt(f)
    }
}
295
impl fmt::Display for DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same rendering as `Debug`: both delegate to the inherent `fmt`.
        self.fmt(f)
    }
}
301
302impl RequiredDist {
303    pub fn single() -> Self {
304        Self::PhysicalDist(Distribution::Single)
305    }
306
307    pub fn shard_by_key(tot_col_num: usize, key: &[usize]) -> Self {
308        let mut cols = FixedBitSet::with_capacity(tot_col_num);
309        for i in key {
310            cols.insert(*i);
311        }
312        assert!(!cols.is_clear());
313        Self::ShardByKey(cols)
314    }
315
316    pub fn shard_by_exact_key(tot_col_num: usize, key: &[usize]) -> Self {
317        let mut cols = FixedBitSet::with_capacity(tot_col_num);
318        for i in key {
319            cols.insert(*i);
320        }
321        assert!(!cols.is_clear());
322        Self::ShardByExactKey(cols)
323    }
324
325    pub fn hash_shard(key: &[usize]) -> Self {
326        assert!(!key.is_empty());
327        Self::PhysicalDist(Distribution::HashShard(key.to_vec()))
328    }
329
330    pub fn batch_enforce_if_not_satisfies(
331        &self,
332        mut plan: BatchPlanRef,
333        required_order: &Order,
334    ) -> Result<BatchPlanRef> {
335        plan = required_order.enforce_if_not_satisfies(plan)?;
336        if !plan.distribution().satisfies(self) {
337            Ok(self.batch_enforce(plan, required_order))
338        } else {
339            Ok(plan)
340        }
341    }
342
343    pub fn streaming_enforce_if_not_satisfies(&self, plan: StreamPlanRef) -> Result<StreamPlanRef> {
344        if !plan.distribution().satisfies(self) {
345            Ok(self.stream_enforce(plan))
346        } else {
347            Ok(plan)
348        }
349    }
350
351    pub fn no_shuffle(plan: StreamPlanRef) -> StreamPlanRef {
352        StreamExchange::new_no_shuffle(plan).into()
353    }
354
355    /// check if the distribution satisfies other required distribution
356    pub fn satisfies(&self, required: &RequiredDist) -> bool {
357        match self {
358            RequiredDist::Any => matches!(required, RequiredDist::Any),
359            RequiredDist::AnyShard => {
360                matches!(required, RequiredDist::Any | RequiredDist::AnyShard)
361            }
362            RequiredDist::ShardByKey(key) => match required {
363                RequiredDist::Any | RequiredDist::AnyShard => true,
364                RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
365                RequiredDist::ShardByExactKey(required_key) => {
366                    key == required_key && key.count_ones(..) == 1
367                }
368                _ => false,
369            },
370            RequiredDist::ShardByExactKey(key) => match required {
371                RequiredDist::Any | RequiredDist::AnyShard => true,
372                RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
373                RequiredDist::ShardByExactKey(required_key) => key == required_key,
374                _ => false,
375            },
376            RequiredDist::PhysicalDist(dist) => dist.satisfies(required),
377        }
378    }
379
380    pub fn batch_enforce(&self, plan: BatchPlanRef, required_order: &Order) -> BatchPlanRef {
381        let dist = self.to_dist();
382        BatchExchange::new(plan, required_order.clone(), dist).into()
383    }
384
385    pub fn stream_enforce(&self, plan: StreamPlanRef) -> StreamPlanRef {
386        let dist = self.to_dist();
387        StreamExchange::new(plan, dist).into()
388    }
389
390    fn to_dist(&self) -> Distribution {
391        match self {
392            // all the distribution satisfy the Any, and the function can be only called by
393            // `enforce_if_not_satisfies`
394            RequiredDist::Any => unreachable!(),
395            // TODO: add round robin distributed type
396            RequiredDist::AnyShard => todo!(),
397            RequiredDist::ShardByKey(required_keys) => {
398                Distribution::HashShard(required_keys.ones().collect())
399            }
400            RequiredDist::ShardByExactKey(required_keys) => {
401                Distribution::HashShard(required_keys.ones().collect())
402            }
403            RequiredDist::PhysicalDist(dist) => dist.clone(),
404        }
405    }
406}
407
408impl StreamPlanRef {
409    /// Eliminate `SomeShard` distribution by using the stream key as the distribution key to
410    /// enforce the current plan to have a known distribution key.
411    pub fn enforce_concrete_distribution(self) -> Self {
412        match self.distribution() {
413            Distribution::SomeShard => {
414                RequiredDist::shard_by_key(self.schema().len(), self.expect_stream_key())
415                    .stream_enforce(self)
416            }
417            _ => self,
418        }
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::{Distribution, RequiredDist};

    #[test]
    fn hash_shard_satisfy() {
        // Provided distributions: hash shards on (0,1), (1,0), (0), (1).
        let d1 = Distribution::HashShard(vec![0, 1]);
        let d2 = Distribution::HashShard(vec![1, 0]);
        let d3 = Distribution::HashShard(vec![0]);
        let d4 = Distribution::HashShard(vec![1]);

        // Required distributions over a 2-column schema.
        let r1 = RequiredDist::shard_by_key(2, &[0, 1]);
        let r3 = RequiredDist::shard_by_key(2, &[0]);
        let r4 = RequiredDist::shard_by_key(2, &[1]);
        let r_exact = RequiredDist::shard_by_exact_key(2, &[0, 1]);
        let r_exact_single = RequiredDist::shard_by_exact_key(2, &[0]);
        // `PhysicalDist` is reflexive: every distribution satisfies itself.
        assert!(d1.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(d2.satisfies(&RequiredDist::PhysicalDist(d2.clone())));
        assert!(d3.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(d4.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        // `PhysicalDist` is exact: key order and key set must both match.
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d3.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d4.satisfies(&RequiredDist::PhysicalDist(d1.clone())));

        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d4.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        // `ShardByKey({0,1})` is satisfied by any hash shard on a subset of {0,1}.
        assert!(d1.satisfies(&r1));
        assert!(d2.satisfies(&r1));
        assert!(d3.satisfies(&r1));
        assert!(d4.satisfies(&r1));

        // A superset or a different key does NOT satisfy a narrower `ShardByKey`.
        assert!(!d1.satisfies(&r3));
        assert!(!d2.satisfies(&r3));
        assert!(d3.satisfies(&r3));
        assert!(!d4.satisfies(&r3));

        assert!(!d1.satisfies(&r4));
        assert!(!d2.satisfies(&r4));
        assert!(!d3.satisfies(&r4));
        assert!(d4.satisfies(&r4));

        // `ShardByExactKey({0,1})` requires exactly the key set {0,1} (order-irrelevant).
        assert!(d1.satisfies(&r_exact));
        assert!(d2.satisfies(&r_exact));
        assert!(!d3.satisfies(&r_exact));
        assert!(!d4.satisfies(&r_exact));

        // Requirement-to-requirement implication: a narrower `ShardByKey` implies a wider one.
        assert!(r3.satisfies(&r1));
        assert!(r4.satisfies(&r1));
        assert!(!r1.satisfies(&r3));
        assert!(!r1.satisfies(&r4));
        assert!(!r3.satisfies(&r4));
        assert!(!r4.satisfies(&r3));

        // `ShardByExactKey` implies the same-keyed `ShardByKey`, but not vice versa
        // unless the key set is a singleton.
        assert!(r_exact.satisfies(&r1));
        assert!(!r1.satisfies(&r_exact));
        assert!(!r3.satisfies(&r_exact));
        assert!(!r_exact.satisfies(&r3));

        // Singleton key sets: `ShardByKey({0})` and `ShardByExactKey({0})` are equivalent.
        assert!(r3.satisfies(&r_exact_single));
        assert!(r_exact_single.satisfies(&r3));
    }
}
487}