risingwave_frontend/optimizer/property/distribution.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//!   "A -> B" represents that A satisfies B
//!                                 x
//!  only as a required property    x  can be used as both required
//!                                 x  and provided property
//!                                 x
//!            ┌───┐                x┌──────┐
//!            │Any◄─────────────────┤single│
//!            └─▲─┘                x└──────┘
//!              │                  x
//!              │                  x
//!              │                  x
//!          ┌───┴────┐             x┌──────────┐
//!          │AnyShard◄──────────────┤SomeShard │
//!          └───▲────┘             x└──────────┘
//!              │                  x
//!          ┌───┴───────────┐      x┌──────────────┐ ┌──────────────┐
//!          │ShardByKey(a,b)◄───┬───┤HashShard(a,b)│ │HashShard(b,a)│
//!          └───▲──▲────────┘   │  x└──────────────┘ └┬─────────────┘
//!              │  │            │  x                  │
//!              │  │            └─────────────────────┘
//!              │  │               x
//!              │ ┌┴────────────┐  x┌────────────┐
//!              │ │ShardByKey(a)◄───┤HashShard(a)│
//!              │ └─────────────┘  x└────────────┘
//!              │                  x
//!             ┌┴────────────┐     x┌────────────┐
//!             │ShardByKey(b)◄──────┤HashShard(b)│
//!             └─────────────┘     x└────────────┘
//!                                 x
//!                                 x
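//!
//! A minimal illustration of the lattice above (a sketch, not compiled as a doctest):
//!
//! ```ignore
//! // `HashShard(a)` satisfies `ShardByKey(a, b)` and `AnyShard`, while the provided
//! // `HashShard(a, b)` does not satisfy the stricter requirement `ShardByKey(a)`.
//! let hash_a = Distribution::HashShard(vec![0]);
//! assert!(hash_a.satisfies(&RequiredDist::shard_by_key(2, &[0, 1])));
//! assert!(hash_a.satisfies(&RequiredDist::AnyShard));
//!
//! let hash_ab = Distribution::HashShard(vec![0, 1]);
//! assert!(!hash_ab.satisfies(&RequiredDist::shard_by_key(2, &[0])));
//! ```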
use std::collections::HashMap;
use std::fmt;
use std::fmt::Debug;

use fixedbitset::FixedBitSet;
use generic::PhysicalPlanRef;
use itertools::Itertools;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::catalog::{FieldDisplay, Schema, TableId};
use risingwave_common::hash::WorkerSlotId;
use risingwave_pb::batch_plan::ExchangeInfo;
use risingwave_pb::batch_plan::exchange_info::{
    ConsistentHashInfo, Distribution as PbDistribution, DistributionMode, HashInfo,
};

use super::super::plan_node::*;
use crate::catalog::FragmentId;
use crate::catalog::catalog_service::CatalogReader;
use crate::error::Result;
use crate::optimizer::property::Order;

/// The distribution property provided by an operator.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Distribution {
    /// There is only one partition. All records are placed on it.
    ///
    /// Note: singleton will not be enforced automatically.
    /// It's set in `crate::stream_fragmenter::build_fragment`,
    /// by setting `requires_singleton` manually.
    Single,
    /// Records are sharded into partitions. This satisfies `AnyShard`, but there is no
    /// guarantee about their placement.
    SomeShard,
    /// Records are sharded into partitions based on the hash value of some columns, which means
    /// records with the same hash values must be on the same partition.
    /// `usize` is the index of the column used as the distribution key.
    HashShard(Vec<usize>),
    /// A special kind of provided distribution which is almost the same as
    /// [`Distribution::HashShard`], but may have a different vnode mapping.
    ///
    /// It exists because the upstream MV can be scaled independently. So we use
    /// `UpstreamHashShard` to **force an exchange to be inserted**.
    ///
    /// Alternatively, [`Distribution::SomeShard`] can also be used to insert an exchange, but
    /// `UpstreamHashShard` contains distribution keys, which might be useful in some cases, e.g.,
    /// two-phase Agg. It also satisfies [`RequiredDist::ShardByKey`].
    ///
    /// `TableId` is used to represent the data distribution (`vnode_mapping`) of this
    /// `UpstreamHashShard`. The scheduler can fetch the `vnode_mapping` corresponding to the
    /// `TableId` to do the shuffle.
    UpstreamHashShard(Vec<usize>, TableId),
    /// Records are available on all downstream shards.
    Broadcast,
}

/// The distribution property requirement.
#[derive(Debug, Clone, PartialEq)]
pub enum RequiredDist {
    /// With any distribution.
    Any,
    /// Records are sharded across partitions, i.e., every record belongs to some partition.
    AnyShard,
    /// Records are sharded across partitions based on some keys (order-irrelevant:
    /// `ShardByKey({a,b})` is equivalent to `ShardByKey({b,a})`), which means records with the
    /// same keys must be on the same partition. Used as a required property only.
    ShardByKey(FixedBitSet),
    /// Must be exactly the same as the given physical distribution.
    PhysicalDist(Distribution),
}

impl Distribution {
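    /// Serialize this distribution into the `ExchangeInfo` proto consumed by the batch
    /// scheduler. For `UpstreamHashShard`, the upstream table's fragment mapping is looked up
    /// via the catalog and the worker node manager and encoded as consistent-hash info.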
    pub fn to_prost(
        &self,
        output_count: u32,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> Result<ExchangeInfo> {
        let exchange_info = ExchangeInfo {
            mode: match self {
                Distribution::Single => DistributionMode::Single,
                Distribution::HashShard(_) => DistributionMode::Hash,
                // TODO: add round robin DistributionMode
                Distribution::SomeShard => DistributionMode::Single,
                Distribution::Broadcast => DistributionMode::Broadcast,
                Distribution::UpstreamHashShard(_, _) => DistributionMode::ConsistentHash,
            } as i32,
            distribution: match self {
                Distribution::Single => None,
                Distribution::HashShard(key) => {
                    assert!(
                        !key.is_empty(),
                        "hash key should not be empty, use `Single` instead"
                    );
                    Some(PbDistribution::HashInfo(HashInfo {
                        output_count,
                        key: key.iter().map(|num| *num as u32).collect(),
                    }))
                }
                // TODO: add round robin distribution
                Distribution::SomeShard => None,
                Distribution::Broadcast => None,
                Distribution::UpstreamHashShard(key, table_id) => {
                    assert!(
                        !key.is_empty(),
                        "hash key should not be empty, use `Single` instead"
                    );

                    let vnode_mapping = worker_node_manager
                        .fragment_mapping(Self::get_fragment_id(catalog_reader, table_id)?)?;

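                    // Assign each distinct worker slot in the vnode mapping a dense index, then
                    // encode every vnode's worker slot as that index in the consistent-hash `vmap`.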
                    let worker_slot_to_id_map: HashMap<WorkerSlotId, u32> = vnode_mapping
                        .iter_unique()
                        .enumerate()
                        .map(|(i, worker_slot_id)| (worker_slot_id, i as u32))
                        .collect();

                    Some(PbDistribution::ConsistentHashInfo(ConsistentHashInfo {
                        vmap: vnode_mapping
                            .iter()
                            .map(|id| worker_slot_to_id_map[&id])
                            .collect_vec(),
                        key: key.iter().map(|num| *num as u32).collect(),
                    }))
                }
            },
        };
        Ok(exchange_info)
    }

    /// Check if this distribution satisfies the other required distribution.
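    ///
    /// A sketch mirroring the unit tests below (not compiled as a doctest):
    ///
    /// ```ignore
    /// let d = Distribution::HashShard(vec![0]);
    /// // Satisfies the order-irrelevant `ShardByKey({0, 1})` requirement ...
    /// assert!(d.satisfies(&RequiredDist::shard_by_key(2, &[0, 1])));
    /// // ... but not the exact physical distribution `HashShard(0, 1)`.
    /// assert!(!d.satisfies(&RequiredDist::PhysicalDist(Distribution::HashShard(vec![0, 1]))));
    /// ```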
    pub fn satisfies(&self, required: &RequiredDist) -> bool {
        match required {
            RequiredDist::Any => true,
            RequiredDist::AnyShard => {
                matches!(
                    self,
                    Distribution::SomeShard
                        | Distribution::HashShard(_)
                        | Distribution::UpstreamHashShard(_, _)
                        | Distribution::Broadcast
                )
            }
            RequiredDist::ShardByKey(required_key) => match self {
                Distribution::HashShard(hash_key)
                | Distribution::UpstreamHashShard(hash_key, _) => {
                    hash_key.iter().all(|idx| required_key.contains(*idx))
                }
                _ => false,
            },
            RequiredDist::PhysicalDist(other) => self == other,
        }
    }

    /// Get distribution column indices. After optimization, only `HashShard` and `Single` are
    /// valid.
    pub fn dist_column_indices(&self) -> &[usize] {
        match self {
            Distribution::Single | Distribution::SomeShard | Distribution::Broadcast => {
                Default::default()
            }
            Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => dists,
        }
    }

    #[inline(always)]
    fn get_fragment_id(catalog_reader: &CatalogReader, table_id: &TableId) -> Result<FragmentId> {
        catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(Into::into)
    }
}

impl fmt::Display for Distribution {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[")?;
        match self {
            Self::Single => f.write_str("Single")?,
            Self::SomeShard => f.write_str("SomeShard")?,
            Self::Broadcast => f.write_str("Broadcast")?,
            Self::HashShard(vec) | Self::UpstreamHashShard(vec, _) => {
                for key in vec {
                    std::fmt::Debug::fmt(&key, f)?;
                }
            }
        }
        f.write_str("]")
    }
}

pub struct DistributionDisplay<'a> {
    pub distribution: &'a Distribution,
    pub input_schema: &'a Schema,
}

impl DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let that = self.distribution;
        match that {
            Distribution::Single => f.write_str("Single"),
            Distribution::SomeShard => f.write_str("SomeShard"),
            Distribution::Broadcast => f.write_str("Broadcast"),
            Distribution::HashShard(vec) | Distribution::UpstreamHashShard(vec, _) => {
                if let Distribution::HashShard(_) = that {
                    f.write_str("HashShard(")?;
                } else {
                    f.write_str("UpstreamHashShard(")?;
                }
                for (pos, key) in vec.iter().copied().with_position() {
                    std::fmt::Debug::fmt(
                        &FieldDisplay(self.input_schema.fields.get(key).unwrap()),
                        f,
                    )?;
                    match pos {
                        itertools::Position::First | itertools::Position::Middle => {
                            f.write_str(", ")?;
                        }
                        _ => {}
                    }
                }
                f.write_str(")")
            }
        }
    }
}

impl fmt::Debug for DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt(f)
    }
}

impl fmt::Display for DistributionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt(f)
    }
}

impl RequiredDist {
    pub fn single() -> Self {
        Self::PhysicalDist(Distribution::Single)
    }

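    /// Build a `ShardByKey` requirement over the given `key` columns of a schema with
    /// `tot_col_num` columns. The key order is irrelevant, as sketched below (not compiled
    /// as a doctest):
    ///
    /// ```ignore
    /// // Both require sharding by a key covering columns 0 and 2 of a 3-column schema.
    /// let r1 = RequiredDist::shard_by_key(3, &[0, 2]);
    /// let r2 = RequiredDist::shard_by_key(3, &[2, 0]);
    /// assert!(r1.satisfies(&r2) && r2.satisfies(&r1));
    /// ```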
    pub fn shard_by_key(tot_col_num: usize, key: &[usize]) -> Self {
        let mut cols = FixedBitSet::with_capacity(tot_col_num);
        for i in key {
            cols.insert(*i);
        }
        assert!(!cols.is_clear());
        Self::ShardByKey(cols)
    }

    pub fn hash_shard(key: &[usize]) -> Self {
        assert!(!key.is_empty());
        Self::PhysicalDist(Distribution::HashShard(key.to_vec()))
    }

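    /// Enforce this required distribution on `plan` if it is not already satisfied.
    /// For batch plans, `required_order` is enforced first; then, if the plan's current
    /// distribution does not satisfy `self`, an exchange is inserted via [`Self::enforce`].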
    pub fn enforce_if_not_satisfies(
        &self,
        mut plan: PlanRef,
        required_order: &Order,
    ) -> Result<PlanRef> {
        if let Convention::Batch = plan.convention() {
            plan = required_order.enforce_if_not_satisfies(plan)?;
        }
        if !plan.distribution().satisfies(self) {
            Ok(self.enforce(plan, required_order))
        } else {
            Ok(plan)
        }
    }

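    /// Wrap `plan` in a no-shuffle exchange. Only stream plans are expected here; logical
    /// and batch plans are unreachable.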
    pub fn no_shuffle(plan: PlanRef) -> PlanRef {
        match plan.convention() {
            Convention::Stream => StreamExchange::new_no_shuffle(plan).into(),
            Convention::Logical | Convention::Batch => unreachable!(),
        }
    }

    /// Check if this required distribution is at least as strict as the other one, i.e.,
    /// any distribution satisfying `self` also satisfies `required`.
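    ///
    /// A sketch mirroring the unit tests below (not compiled as a doctest):
    ///
    /// ```ignore
    /// // `ShardByKey({0})` is stricter than `ShardByKey({0, 1})`, so it satisfies the
    /// // latter but not vice versa.
    /// let r_ab = RequiredDist::shard_by_key(2, &[0, 1]);
    /// let r_a = RequiredDist::shard_by_key(2, &[0]);
    /// assert!(r_a.satisfies(&r_ab));
    /// assert!(!r_ab.satisfies(&r_a));
    /// ```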
    pub fn satisfies(&self, required: &RequiredDist) -> bool {
        match self {
            RequiredDist::Any => matches!(required, RequiredDist::Any),
            RequiredDist::AnyShard => {
                matches!(required, RequiredDist::Any | RequiredDist::AnyShard)
            }
            RequiredDist::ShardByKey(key) => match required {
                RequiredDist::Any | RequiredDist::AnyShard => true,
                RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
                _ => false,
            },
            RequiredDist::PhysicalDist(dist) => dist.satisfies(required),
        }
    }

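    /// Insert an exchange that enforces this required distribution on `plan`: a
    /// `BatchExchange` (with `required_order`) for batch plans, or a `StreamExchange` for
    /// stream plans.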
    pub fn enforce(&self, plan: PlanRef, required_order: &Order) -> PlanRef {
        let dist = self.to_dist();
        match plan.convention() {
            Convention::Batch => BatchExchange::new(plan, required_order.clone(), dist).into(),
            Convention::Stream => StreamExchange::new(plan, dist).into(),
            _ => unreachable!(),
        }
    }

    fn to_dist(&self) -> Distribution {
        match self {
            // All distributions satisfy `Any`, and this function can only be called by
            // `enforce_if_not_satisfies`, so `Any` is unreachable here.
            RequiredDist::Any => unreachable!(),
            // TODO: add round robin distribution type
            RequiredDist::AnyShard => todo!(),
            RequiredDist::ShardByKey(required_keys) => {
                Distribution::HashShard(required_keys.ones().collect())
            }
            RequiredDist::PhysicalDist(dist) => dist.clone(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{Distribution, RequiredDist};

    #[test]
    fn hash_shard_satisfy() {
        let d1 = Distribution::HashShard(vec![0, 1]);
        let d2 = Distribution::HashShard(vec![1, 0]);
        let d3 = Distribution::HashShard(vec![0]);
        let d4 = Distribution::HashShard(vec![1]);

        let r1 = RequiredDist::shard_by_key(2, &[0, 1]);
        let r3 = RequiredDist::shard_by_key(2, &[0]);
        let r4 = RequiredDist::shard_by_key(2, &[1]);
        assert!(d1.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(d2.satisfies(&RequiredDist::PhysicalDist(d2.clone())));
        assert!(d3.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(d4.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d3.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d4.satisfies(&RequiredDist::PhysicalDist(d1.clone())));

        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d4.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        assert!(d1.satisfies(&r1));
        assert!(d2.satisfies(&r1));
        assert!(d3.satisfies(&r1));
        assert!(d4.satisfies(&r1));

        assert!(!d1.satisfies(&r3));
        assert!(!d2.satisfies(&r3));
        assert!(d3.satisfies(&r3));
        assert!(!d4.satisfies(&r3));

        assert!(!d1.satisfies(&r4));
        assert!(!d2.satisfies(&r4));
        assert!(!d3.satisfies(&r4));
        assert!(d4.satisfies(&r4));

        assert!(r3.satisfies(&r1));
        assert!(r4.satisfies(&r1));
        assert!(!r1.satisfies(&r3));
        assert!(!r1.satisfies(&r4));
        assert!(!r3.satisfies(&r4));
        assert!(!r4.satisfies(&r3));
    }
}