risingwave_frontend/optimizer/property/
distribution.rs1use std::collections::HashMap;
46use std::fmt;
47use std::fmt::Debug;
48
49use fixedbitset::FixedBitSet;
50use generic::PhysicalPlanRef;
51use itertools::Itertools;
52use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
53use risingwave_common::catalog::{FieldDisplay, Schema, TableId};
54use risingwave_common::hash::WorkerSlotId;
55use risingwave_pb::batch_plan::ExchangeInfo;
56use risingwave_pb::batch_plan::exchange_info::{
57 ConsistentHashInfo, Distribution as PbDistribution, DistributionMode, HashInfo,
58};
59
60use super::super::plan_node::*;
61use crate::catalog::FragmentId;
62use crate::catalog::catalog_service::CatalogReader;
63use crate::error::Result;
64use crate::optimizer::property::Order;
65
66#[derive(Debug, Clone, PartialEq, Eq, Hash)]
68pub enum Distribution {
69 Single,
75 SomeShard,
78 HashShard(Vec<usize>),
82 UpstreamHashShard(Vec<usize>, TableId),
96 Broadcast,
98}
99
100#[derive(Debug, Clone, PartialEq)]
102pub enum RequiredDist {
103 Any,
105 AnyShard,
107 ShardByKey(FixedBitSet),
112 ShardByExactKey(FixedBitSet),
115 PhysicalDist(Distribution),
117}
118
119impl Distribution {
120 pub fn to_prost(
121 &self,
122 output_count: u32,
123 catalog_reader: &CatalogReader,
124 worker_node_manager: &WorkerNodeSelector,
125 ) -> Result<ExchangeInfo> {
126 let exchange_info = ExchangeInfo {
127 mode: match self {
128 Distribution::Single => DistributionMode::Single,
129 Distribution::HashShard(_) => DistributionMode::Hash,
130 Distribution::SomeShard => DistributionMode::Single,
132 Distribution::Broadcast => DistributionMode::Broadcast,
133 Distribution::UpstreamHashShard(_, _) => DistributionMode::ConsistentHash,
134 } as i32,
135 distribution: match self {
136 Distribution::Single => None,
137 Distribution::HashShard(key) => {
138 assert!(
139 !key.is_empty(),
140 "hash key should not be empty, use `Single` instead"
141 );
142 Some(PbDistribution::HashInfo(HashInfo {
143 output_count,
144 key: key.iter().map(|num| *num as u32).collect(),
145 }))
146 }
147 Distribution::SomeShard => None,
149 Distribution::Broadcast => None,
150 Distribution::UpstreamHashShard(key, table_id) => {
151 assert!(
152 !key.is_empty(),
153 "hash key should not be empty, use `Single` instead"
154 );
155
156 let vnode_mapping = worker_node_manager
157 .fragment_mapping(Self::get_fragment_id(catalog_reader, *table_id)?)?;
158
159 let worker_slot_to_id_map: HashMap<WorkerSlotId, u32> = vnode_mapping
160 .iter_unique()
161 .enumerate()
162 .map(|(i, worker_slot_id)| (worker_slot_id, i as u32))
163 .collect();
164
165 Some(PbDistribution::ConsistentHashInfo(ConsistentHashInfo {
166 vmap: vnode_mapping
167 .iter()
168 .map(|id| worker_slot_to_id_map[&id])
169 .collect_vec(),
170 key: key.iter().map(|num| *num as u32).collect(),
171 }))
172 }
173 },
174 };
175 Ok(exchange_info)
176 }
177
178 pub fn satisfies(&self, required: &RequiredDist) -> bool {
180 match required {
181 RequiredDist::Any => true,
182 RequiredDist::AnyShard => {
183 matches!(
184 self,
185 Distribution::SomeShard
186 | Distribution::HashShard(_)
187 | Distribution::UpstreamHashShard(_, _)
188 | Distribution::Broadcast
189 )
190 }
191 RequiredDist::ShardByKey(required_key) => match self {
192 Distribution::HashShard(hash_key)
193 | Distribution::UpstreamHashShard(hash_key, _) => {
194 hash_key.iter().all(|idx| required_key.contains(*idx))
195 }
196 _ => false,
197 },
198 RequiredDist::ShardByExactKey(required_key) => match self {
199 Distribution::HashShard(hash_key)
200 | Distribution::UpstreamHashShard(hash_key, _) => {
201 hash_key.len() == required_key.count_ones(..)
202 && hash_key.iter().all(|idx| required_key.contains(*idx))
203 }
204 _ => false,
205 },
206 RequiredDist::PhysicalDist(other) => self == other,
207 }
208 }
209
210 pub fn dist_column_indices(&self) -> &[usize] {
212 self.dist_column_indices_opt()
213 .unwrap_or_else(|| panic!("cannot obtain distribution columns for {self:?}"))
214 }
215
216 pub fn dist_column_indices_opt(&self) -> Option<&[usize]> {
218 match self {
219 Distribution::Single => Some(&[]),
220 Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => {
221 Some(dists)
222 }
223 Distribution::SomeShard | Distribution::Broadcast => None,
224 }
225 }
226
227 #[inline(always)]
228 fn get_fragment_id(catalog_reader: &CatalogReader, table_id: TableId) -> Result<FragmentId> {
229 catalog_reader
230 .read_guard()
231 .get_any_table_by_id(table_id)
232 .map(|table| table.fragment_id)
233 .map_err(Into::into)
234 }
235}
236
237impl fmt::Display for Distribution {
238 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239 f.write_str("[")?;
240 match self {
241 Self::Single => f.write_str("Single")?,
242 Self::SomeShard => f.write_str("SomeShard")?,
243 Self::Broadcast => f.write_str("Broadcast")?,
244 Self::HashShard(vec) | Self::UpstreamHashShard(vec, _) => {
245 for key in vec {
246 std::fmt::Debug::fmt(&key, f)?;
247 }
248 }
249 }
250 f.write_str("]")
251 }
252}
253
254pub struct DistributionDisplay<'a> {
255 pub distribution: &'a Distribution,
256 pub input_schema: &'a Schema,
257}
258
259impl DistributionDisplay<'_> {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 let that = self.distribution;
262 match that {
263 Distribution::Single => f.write_str("Single"),
264 Distribution::SomeShard => f.write_str("SomeShard"),
265 Distribution::Broadcast => f.write_str("Broadcast"),
266 Distribution::HashShard(vec) | Distribution::UpstreamHashShard(vec, _) => {
267 if let Distribution::HashShard(_) = that {
268 f.write_str("HashShard(")?;
269 } else {
270 f.write_str("UpstreamHashShard(")?;
271 }
272 for (pos, key) in vec.iter().copied().with_position() {
273 std::fmt::Debug::fmt(
274 &FieldDisplay(self.input_schema.fields.get(key).unwrap()),
275 f,
276 )?;
277 match pos {
278 itertools::Position::First | itertools::Position::Middle => {
279 f.write_str(", ")?;
280 }
281 _ => {}
282 }
283 }
284 f.write_str(")")
285 }
286 }
287 }
288}
289
290impl fmt::Debug for DistributionDisplay<'_> {
291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292 self.fmt(f)
293 }
294}
295
296impl fmt::Display for DistributionDisplay<'_> {
297 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298 self.fmt(f)
299 }
300}
301
302impl RequiredDist {
303 pub fn single() -> Self {
304 Self::PhysicalDist(Distribution::Single)
305 }
306
307 pub fn shard_by_key(tot_col_num: usize, key: &[usize]) -> Self {
308 let mut cols = FixedBitSet::with_capacity(tot_col_num);
309 for i in key {
310 cols.insert(*i);
311 }
312 assert!(!cols.is_clear());
313 Self::ShardByKey(cols)
314 }
315
316 pub fn shard_by_exact_key(tot_col_num: usize, key: &[usize]) -> Self {
317 let mut cols = FixedBitSet::with_capacity(tot_col_num);
318 for i in key {
319 cols.insert(*i);
320 }
321 assert!(!cols.is_clear());
322 Self::ShardByExactKey(cols)
323 }
324
325 pub fn hash_shard(key: &[usize]) -> Self {
326 assert!(!key.is_empty());
327 Self::PhysicalDist(Distribution::HashShard(key.to_vec()))
328 }
329
330 pub fn batch_enforce_if_not_satisfies(
331 &self,
332 mut plan: BatchPlanRef,
333 required_order: &Order,
334 ) -> Result<BatchPlanRef> {
335 plan = required_order.enforce_if_not_satisfies(plan)?;
336 if !plan.distribution().satisfies(self) {
337 Ok(self.batch_enforce(plan, required_order))
338 } else {
339 Ok(plan)
340 }
341 }
342
343 pub fn streaming_enforce_if_not_satisfies(&self, plan: StreamPlanRef) -> Result<StreamPlanRef> {
344 if !plan.distribution().satisfies(self) {
345 Ok(self.stream_enforce(plan))
346 } else {
347 Ok(plan)
348 }
349 }
350
351 pub fn no_shuffle(plan: StreamPlanRef) -> StreamPlanRef {
352 StreamExchange::new_no_shuffle(plan).into()
353 }
354
355 pub fn satisfies(&self, required: &RequiredDist) -> bool {
357 match self {
358 RequiredDist::Any => matches!(required, RequiredDist::Any),
359 RequiredDist::AnyShard => {
360 matches!(required, RequiredDist::Any | RequiredDist::AnyShard)
361 }
362 RequiredDist::ShardByKey(key) => match required {
363 RequiredDist::Any | RequiredDist::AnyShard => true,
364 RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
365 RequiredDist::ShardByExactKey(required_key) => {
366 key == required_key && key.count_ones(..) == 1
367 }
368 _ => false,
369 },
370 RequiredDist::ShardByExactKey(key) => match required {
371 RequiredDist::Any | RequiredDist::AnyShard => true,
372 RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
373 RequiredDist::ShardByExactKey(required_key) => key == required_key,
374 _ => false,
375 },
376 RequiredDist::PhysicalDist(dist) => dist.satisfies(required),
377 }
378 }
379
380 pub fn batch_enforce(&self, plan: BatchPlanRef, required_order: &Order) -> BatchPlanRef {
381 let dist = self.to_dist();
382 BatchExchange::new(plan, required_order.clone(), dist).into()
383 }
384
385 pub fn stream_enforce(&self, plan: StreamPlanRef) -> StreamPlanRef {
386 let dist = self.to_dist();
387 StreamExchange::new(plan, dist).into()
388 }
389
390 fn to_dist(&self) -> Distribution {
391 match self {
392 RequiredDist::Any => unreachable!(),
395 RequiredDist::AnyShard => todo!(),
397 RequiredDist::ShardByKey(required_keys) => {
398 Distribution::HashShard(required_keys.ones().collect())
399 }
400 RequiredDist::ShardByExactKey(required_keys) => {
401 Distribution::HashShard(required_keys.ones().collect())
402 }
403 RequiredDist::PhysicalDist(dist) => dist.clone(),
404 }
405 }
406}
407
408impl StreamPlanRef {
409 pub fn enforce_concrete_distribution(self) -> Self {
412 match self.distribution() {
413 Distribution::SomeShard => {
414 RequiredDist::shard_by_key(self.schema().len(), self.expect_stream_key())
415 .stream_enforce(self)
416 }
417 _ => self,
418 }
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::{Distribution, RequiredDist};

    /// Table-driven checks of `Distribution::satisfies` and `RequiredDist::satisfies` for hash
    /// distributions over a two-column schema.
    #[test]
    fn hash_shard_satisfy() {
        let hash_01 = Distribution::HashShard(vec![0, 1]);
        let hash_10 = Distribution::HashShard(vec![1, 0]);
        let hash_0 = Distribution::HashShard(vec![0]);
        let hash_1 = Distribution::HashShard(vec![1]);

        let by_01 = RequiredDist::shard_by_key(2, &[0, 1]);
        let by_0 = RequiredDist::shard_by_key(2, &[0]);
        let by_1 = RequiredDist::shard_by_key(2, &[1]);
        let exact_01 = RequiredDist::shard_by_exact_key(2, &[0, 1]);
        let exact_0 = RequiredDist::shard_by_exact_key(2, &[0]);

        let physical = |dist: &Distribution| RequiredDist::PhysicalDist(dist.clone());

        // Physical requirements demand equality, so even a permuted key does not match.
        let physical_cases = [
            (&hash_01, &hash_01, true),
            (&hash_10, &hash_10, true),
            (&hash_0, &hash_0, true),
            (&hash_1, &hash_1, true),
            (&hash_10, &hash_01, false),
            (&hash_0, &hash_01, false),
            (&hash_1, &hash_01, false),
            (&hash_01, &hash_0, false),
            (&hash_10, &hash_0, false),
            (&hash_01, &hash_1, false),
            (&hash_10, &hash_1, false),
        ];
        for (provided, wanted, expected) in physical_cases {
            assert_eq!(
                provided.satisfies(&physical(wanted)),
                expected,
                "{provided:?} vs PhysicalDist({wanted:?})"
            );
        }

        // (provided distribution, logical requirement, expected)
        let logical_cases = [
            (&hash_01, &by_01, true),
            (&hash_10, &by_01, true),
            (&hash_0, &by_01, true),
            (&hash_1, &by_01, true),
            (&hash_01, &by_0, false),
            (&hash_10, &by_0, false),
            (&hash_0, &by_0, true),
            (&hash_1, &by_0, false),
            (&hash_01, &by_1, false),
            (&hash_10, &by_1, false),
            (&hash_0, &by_1, false),
            (&hash_1, &by_1, true),
            (&hash_01, &exact_01, true),
            (&hash_10, &exact_01, true),
            (&hash_0, &exact_01, false),
            (&hash_1, &exact_01, false),
        ];
        for (provided, required, expected) in logical_cases {
            assert_eq!(
                provided.satisfies(required),
                expected,
                "{provided:?} vs {required:?}"
            );
        }

        // (requirement, other requirement, whether the first implies the second)
        let requirement_cases = [
            (&by_0, &by_01, true),
            (&by_1, &by_01, true),
            (&by_01, &by_0, false),
            (&by_01, &by_1, false),
            (&by_0, &by_1, false),
            (&by_1, &by_0, false),
            (&exact_01, &by_01, true),
            (&by_01, &exact_01, false),
            (&by_0, &exact_01, false),
            (&exact_01, &by_0, false),
            (&by_0, &exact_0, true),
            (&exact_0, &by_0, true),
        ];
        for (stricter, looser, expected) in requirement_cases {
            assert_eq!(
                stricter.satisfies(looser),
                expected,
                "{stricter:?} vs {looser:?}"
            );
        }
    }
}