risingwave_frontend/optimizer/property/
distribution.rs1use std::collections::HashMap;
46use std::fmt;
47use std::fmt::Debug;
48
49use fixedbitset::FixedBitSet;
50use generic::PhysicalPlanRef;
51use itertools::Itertools;
52use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
53use risingwave_common::catalog::{FieldDisplay, Schema, TableId};
54use risingwave_common::hash::WorkerSlotId;
55use risingwave_pb::batch_plan::ExchangeInfo;
56use risingwave_pb::batch_plan::exchange_info::{
57 ConsistentHashInfo, Distribution as PbDistribution, DistributionMode, HashInfo,
58};
59
60use super::super::plan_node::*;
61use crate::catalog::FragmentId;
62use crate::catalog::catalog_service::CatalogReader;
63use crate::error::Result;
64use crate::optimizer::property::Order;
65
66#[derive(Debug, Clone, PartialEq, Eq, Hash)]
68pub enum Distribution {
69 Single,
75 SomeShard,
78 HashShard(Vec<usize>),
82 UpstreamHashShard(Vec<usize>, TableId),
96 Broadcast,
98}
99
100#[derive(Debug, Clone, PartialEq)]
102pub enum RequiredDist {
103 Any,
105 AnyShard,
107 ShardByKey(FixedBitSet),
111 PhysicalDist(Distribution),
113}
114
115impl Distribution {
116 pub fn to_prost(
117 &self,
118 output_count: u32,
119 catalog_reader: &CatalogReader,
120 worker_node_manager: &WorkerNodeSelector,
121 ) -> Result<ExchangeInfo> {
122 let exchange_info = ExchangeInfo {
123 mode: match self {
124 Distribution::Single => DistributionMode::Single,
125 Distribution::HashShard(_) => DistributionMode::Hash,
126 Distribution::SomeShard => DistributionMode::Single,
128 Distribution::Broadcast => DistributionMode::Broadcast,
129 Distribution::UpstreamHashShard(_, _) => DistributionMode::ConsistentHash,
130 } as i32,
131 distribution: match self {
132 Distribution::Single => None,
133 Distribution::HashShard(key) => {
134 assert!(
135 !key.is_empty(),
136 "hash key should not be empty, use `Single` instead"
137 );
138 Some(PbDistribution::HashInfo(HashInfo {
139 output_count,
140 key: key.iter().map(|num| *num as u32).collect(),
141 }))
142 }
143 Distribution::SomeShard => None,
145 Distribution::Broadcast => None,
146 Distribution::UpstreamHashShard(key, table_id) => {
147 assert!(
148 !key.is_empty(),
149 "hash key should not be empty, use `Single` instead"
150 );
151
152 let vnode_mapping = worker_node_manager
153 .fragment_mapping(Self::get_fragment_id(catalog_reader, table_id)?)?;
154
155 let worker_slot_to_id_map: HashMap<WorkerSlotId, u32> = vnode_mapping
156 .iter_unique()
157 .enumerate()
158 .map(|(i, worker_slot_id)| (worker_slot_id, i as u32))
159 .collect();
160
161 Some(PbDistribution::ConsistentHashInfo(ConsistentHashInfo {
162 vmap: vnode_mapping
163 .iter()
164 .map(|id| worker_slot_to_id_map[&id])
165 .collect_vec(),
166 key: key.iter().map(|num| *num as u32).collect(),
167 }))
168 }
169 },
170 };
171 Ok(exchange_info)
172 }
173
174 pub fn satisfies(&self, required: &RequiredDist) -> bool {
176 match required {
177 RequiredDist::Any => true,
178 RequiredDist::AnyShard => {
179 matches!(
180 self,
181 Distribution::SomeShard
182 | Distribution::HashShard(_)
183 | Distribution::UpstreamHashShard(_, _)
184 | Distribution::Broadcast
185 )
186 }
187 RequiredDist::ShardByKey(required_key) => match self {
188 Distribution::HashShard(hash_key)
189 | Distribution::UpstreamHashShard(hash_key, _) => {
190 hash_key.iter().all(|idx| required_key.contains(*idx))
191 }
192 _ => false,
193 },
194 RequiredDist::PhysicalDist(other) => self == other,
195 }
196 }
197
198 pub fn dist_column_indices(&self) -> &[usize] {
200 self.dist_column_indices_opt()
201 .unwrap_or_else(|| panic!("cannot obtain distribution columns for {self:?}"))
202 }
203
204 pub fn dist_column_indices_opt(&self) -> Option<&[usize]> {
206 match self {
207 Distribution::Single => Some(&[]),
208 Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => {
209 Some(dists)
210 }
211 Distribution::SomeShard | Distribution::Broadcast => None,
212 }
213 }
214
215 #[inline(always)]
216 fn get_fragment_id(catalog_reader: &CatalogReader, table_id: &TableId) -> Result<FragmentId> {
217 catalog_reader
218 .read_guard()
219 .get_any_table_by_id(table_id)
220 .map(|table| table.fragment_id)
221 .map_err(Into::into)
222 }
223}
224
225impl fmt::Display for Distribution {
226 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227 f.write_str("[")?;
228 match self {
229 Self::Single => f.write_str("Single")?,
230 Self::SomeShard => f.write_str("SomeShard")?,
231 Self::Broadcast => f.write_str("Broadcast")?,
232 Self::HashShard(vec) | Self::UpstreamHashShard(vec, _) => {
233 for key in vec {
234 std::fmt::Debug::fmt(&key, f)?;
235 }
236 }
237 }
238 f.write_str("]")
239 }
240}
241
242pub struct DistributionDisplay<'a> {
243 pub distribution: &'a Distribution,
244 pub input_schema: &'a Schema,
245}
246
247impl DistributionDisplay<'_> {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 let that = self.distribution;
250 match that {
251 Distribution::Single => f.write_str("Single"),
252 Distribution::SomeShard => f.write_str("SomeShard"),
253 Distribution::Broadcast => f.write_str("Broadcast"),
254 Distribution::HashShard(vec) | Distribution::UpstreamHashShard(vec, _) => {
255 if let Distribution::HashShard(_) = that {
256 f.write_str("HashShard(")?;
257 } else {
258 f.write_str("UpstreamHashShard(")?;
259 }
260 for (pos, key) in vec.iter().copied().with_position() {
261 std::fmt::Debug::fmt(
262 &FieldDisplay(self.input_schema.fields.get(key).unwrap()),
263 f,
264 )?;
265 match pos {
266 itertools::Position::First | itertools::Position::Middle => {
267 f.write_str(", ")?;
268 }
269 _ => {}
270 }
271 }
272 f.write_str(")")
273 }
274 }
275 }
276}
277
278impl fmt::Debug for DistributionDisplay<'_> {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 self.fmt(f)
281 }
282}
283
284impl fmt::Display for DistributionDisplay<'_> {
285 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286 self.fmt(f)
287 }
288}
289
290impl RequiredDist {
291 pub fn single() -> Self {
292 Self::PhysicalDist(Distribution::Single)
293 }
294
295 pub fn shard_by_key(tot_col_num: usize, key: &[usize]) -> Self {
296 let mut cols = FixedBitSet::with_capacity(tot_col_num);
297 for i in key {
298 cols.insert(*i);
299 }
300 assert!(!cols.is_clear());
301 Self::ShardByKey(cols)
302 }
303
304 pub fn hash_shard(key: &[usize]) -> Self {
305 assert!(!key.is_empty());
306 Self::PhysicalDist(Distribution::HashShard(key.to_vec()))
307 }
308
309 pub fn batch_enforce_if_not_satisfies(
310 &self,
311 mut plan: BatchPlanRef,
312 required_order: &Order,
313 ) -> Result<BatchPlanRef> {
314 plan = required_order.enforce_if_not_satisfies(plan)?;
315 if !plan.distribution().satisfies(self) {
316 Ok(self.batch_enforce(plan, required_order))
317 } else {
318 Ok(plan)
319 }
320 }
321
322 pub fn streaming_enforce_if_not_satisfies(&self, plan: StreamPlanRef) -> Result<StreamPlanRef> {
323 if !plan.distribution().satisfies(self) {
324 Ok(self.stream_enforce(plan))
325 } else {
326 Ok(plan)
327 }
328 }
329
330 pub fn no_shuffle(plan: StreamPlanRef) -> StreamPlanRef {
331 StreamExchange::new_no_shuffle(plan).into()
332 }
333
334 pub fn satisfies(&self, required: &RequiredDist) -> bool {
336 match self {
337 RequiredDist::Any => matches!(required, RequiredDist::Any),
338 RequiredDist::AnyShard => {
339 matches!(required, RequiredDist::Any | RequiredDist::AnyShard)
340 }
341 RequiredDist::ShardByKey(key) => match required {
342 RequiredDist::Any | RequiredDist::AnyShard => true,
343 RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
344 _ => false,
345 },
346 RequiredDist::PhysicalDist(dist) => dist.satisfies(required),
347 }
348 }
349
350 pub fn batch_enforce(&self, plan: BatchPlanRef, required_order: &Order) -> BatchPlanRef {
351 let dist = self.to_dist();
352 BatchExchange::new(plan, required_order.clone(), dist).into()
353 }
354
355 pub fn stream_enforce(&self, plan: StreamPlanRef) -> StreamPlanRef {
356 let dist = self.to_dist();
357 StreamExchange::new(plan, dist).into()
358 }
359
360 fn to_dist(&self) -> Distribution {
361 match self {
362 RequiredDist::Any => unreachable!(),
365 RequiredDist::AnyShard => todo!(),
367 RequiredDist::ShardByKey(required_keys) => {
368 Distribution::HashShard(required_keys.ones().collect())
369 }
370 RequiredDist::PhysicalDist(dist) => dist.clone(),
371 }
372 }
373}
374
375impl StreamPlanRef {
376 pub fn enforce_concrete_distribution(self) -> Self {
379 match self.distribution() {
380 Distribution::SomeShard => {
381 RequiredDist::shard_by_key(self.schema().len(), self.expect_stream_key())
382 .stream_enforce(self)
383 }
384 _ => self,
385 }
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::{Distribution, RequiredDist};

    /// Hash distributions against exact (`PhysicalDist`) and key-set (`ShardByKey`)
    /// requirements, plus inclusion between key-set requirements.
    #[test]
    fn hash_shard_satisfy() {
        let d1 = Distribution::HashShard(vec![0, 1]);
        let d2 = Distribution::HashShard(vec![1, 0]);
        let d3 = Distribution::HashShard(vec![0]);
        let d4 = Distribution::HashShard(vec![1]);

        let r1 = RequiredDist::shard_by_key(2, &[0, 1]);
        let r3 = RequiredDist::shard_by_key(2, &[0]);
        let r4 = RequiredDist::shard_by_key(2, &[1]);
        assert!(d1.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(d2.satisfies(&RequiredDist::PhysicalDist(d2.clone())));
        assert!(d3.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(d4.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d3.satisfies(&RequiredDist::PhysicalDist(d1.clone())));
        assert!(!d4.satisfies(&RequiredDist::PhysicalDist(d1.clone())));

        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d3.clone())));
        assert!(!d1.satisfies(&RequiredDist::PhysicalDist(d4.clone())));
        assert!(!d2.satisfies(&RequiredDist::PhysicalDist(d4.clone())));

        assert!(d1.satisfies(&r1));
        assert!(d2.satisfies(&r1));
        assert!(d3.satisfies(&r1));
        assert!(d4.satisfies(&r1));

        assert!(!d1.satisfies(&r3));
        assert!(!d2.satisfies(&r3));
        assert!(d3.satisfies(&r3));
        assert!(!d4.satisfies(&r3));

        assert!(!d1.satisfies(&r4));
        assert!(!d2.satisfies(&r4));
        assert!(!d3.satisfies(&r4));
        assert!(d4.satisfies(&r4));

        assert!(r3.satisfies(&r1));
        assert!(r4.satisfies(&r1));
        assert!(!r1.satisfies(&r3));
        assert!(!r1.satisfies(&r4));
        assert!(!r3.satisfies(&r4));
        assert!(!r4.satisfies(&r3));
    }

    /// Non-hash distributions against the `Any`, `AnyShard`, and `ShardByKey` requirements.
    #[test]
    fn non_hash_distributions_satisfy() {
        let any = RequiredDist::Any;
        let any_shard = RequiredDist::AnyShard;
        let by_key = RequiredDist::shard_by_key(2, &[0]);

        // `Any` is met by every distribution.
        for dist in [
            Distribution::Single,
            Distribution::SomeShard,
            Distribution::Broadcast,
            Distribution::HashShard(vec![0]),
        ] {
            assert!(dist.satisfies(&any));
        }

        // `AnyShard` is met by everything except `Single`.
        assert!(!Distribution::Single.satisfies(&any_shard));
        assert!(Distribution::SomeShard.satisfies(&any_shard));
        assert!(Distribution::Broadcast.satisfies(&any_shard));
        assert!(Distribution::HashShard(vec![0]).satisfies(&any_shard));

        // Only hash distributions can meet a key requirement.
        assert!(!Distribution::Single.satisfies(&by_key));
        assert!(!Distribution::SomeShard.satisfies(&by_key));
        assert!(!Distribution::Broadcast.satisfies(&by_key));

        // Exact requirements need equality.
        assert!(Distribution::Single.satisfies(&RequiredDist::single()));
        assert!(!Distribution::SomeShard.satisfies(&RequiredDist::single()));
    }

    /// Strength ordering between requirements of different variants.
    #[test]
    fn required_dist_strength() {
        let any = RequiredDist::Any;
        let any_shard = RequiredDist::AnyShard;
        let by_key = RequiredDist::shard_by_key(2, &[0, 1]);
        let single = RequiredDist::single();
        let hash = RequiredDist::hash_shard(&[0]);

        assert!(any.satisfies(&any));
        assert!(!any.satisfies(&any_shard));
        assert!(any_shard.satisfies(&any));
        assert!(any_shard.satisfies(&any_shard));
        assert!(!any_shard.satisfies(&by_key));
        assert!(by_key.satisfies(&any));
        assert!(by_key.satisfies(&any_shard));
        assert!(!by_key.satisfies(&single));
        assert!(single.satisfies(&any));
        assert!(!single.satisfies(&any_shard));
        assert!(hash.satisfies(&any_shard));
        assert!(hash.satisfies(&by_key));
        assert!(!hash.satisfies(&single));
    }

    /// Distribution columns are exposed only for `Single` (empty) and the hash variants.
    #[test]
    fn dist_column_indices() {
        assert_eq!(Distribution::Single.dist_column_indices_opt(), Some(&[][..]));
        assert_eq!(Distribution::SomeShard.dist_column_indices_opt(), None);
        assert_eq!(Distribution::Broadcast.dist_column_indices_opt(), None);
        assert_eq!(
            Distribution::HashShard(vec![2, 0]).dist_column_indices(),
            &[2, 0][..]
        );
    }
}