risingwave_frontend/optimizer/property/
distribution.rs1use std::collections::HashMap;
46use std::fmt;
47use std::fmt::Debug;
48
49use fixedbitset::FixedBitSet;
50use generic::PhysicalPlanRef;
51use itertools::Itertools;
52use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
53use risingwave_common::catalog::{FieldDisplay, Schema, TableId};
54use risingwave_common::hash::WorkerSlotId;
55use risingwave_pb::batch_plan::ExchangeInfo;
56use risingwave_pb::batch_plan::exchange_info::{
57 ConsistentHashInfo, Distribution as PbDistribution, DistributionMode, HashInfo,
58};
59
60use super::super::plan_node::*;
61use crate::catalog::FragmentId;
62use crate::catalog::catalog_service::CatalogReader;
63use crate::error::Result;
64use crate::optimizer::property::Order;
65
/// Data distribution of a plan's output, as reported by `distribution()` on plan refs and
/// checked against a [`RequiredDist`] via [`Distribution::satisfies`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Distribution {
    /// Not sharded; `to_prost` encodes it as `DistributionMode::Single`.
    Single,
    /// Sharded in some unspecified way. `to_prost` currently encodes it as
    /// `DistributionMode::Single`, and `StreamPlanRef::enforce_concrete_distribution`
    /// replaces it with a hash shard on the stream key.
    SomeShard,
    /// Hash-sharded on the listed column indices. Key order is significant for equality
    /// (`[0, 1]` != `[1, 0]`); `to_prost` asserts the key is non-empty.
    HashShard(Vec<usize>),
    /// Sharded on the listed column indices following the vnode mapping of the fragment that
    /// owns the table identified by the `TableId`; `to_prost` encodes it as
    /// `DistributionMode::ConsistentHash` and asserts the key is non-empty.
    UpstreamHashShard(Vec<usize>, TableId),
    /// Encoded as `DistributionMode::Broadcast` by `to_prost` (presumably every shard sees
    /// all rows — confirm against the exchange implementation).
    Broadcast,
}
99
/// A distribution requirement placed on a plan; checked with [`Distribution::satisfies`] and
/// enforced (by inserting an exchange) with the `*_enforce*` helpers on this type.
#[derive(Debug, Clone, PartialEq)]
pub enum RequiredDist {
    /// No requirement: satisfied by every distribution.
    Any,
    /// Satisfied by any distribution except `Distribution::Single`.
    AnyShard,
    /// Satisfied by a `HashShard` / `UpstreamHashShard` whose key columns all lie in this
    /// column set (any subset, in any order).
    ShardByKey(FixedBitSet),
    /// Satisfied only by a distribution equal to the wrapped one.
    PhysicalDist(Distribution),
}
114
115impl Distribution {
116 pub fn to_prost(
117 &self,
118 output_count: u32,
119 catalog_reader: &CatalogReader,
120 worker_node_manager: &WorkerNodeSelector,
121 ) -> Result<ExchangeInfo> {
122 let exchange_info = ExchangeInfo {
123 mode: match self {
124 Distribution::Single => DistributionMode::Single,
125 Distribution::HashShard(_) => DistributionMode::Hash,
126 Distribution::SomeShard => DistributionMode::Single,
128 Distribution::Broadcast => DistributionMode::Broadcast,
129 Distribution::UpstreamHashShard(_, _) => DistributionMode::ConsistentHash,
130 } as i32,
131 distribution: match self {
132 Distribution::Single => None,
133 Distribution::HashShard(key) => {
134 assert!(
135 !key.is_empty(),
136 "hash key should not be empty, use `Single` instead"
137 );
138 Some(PbDistribution::HashInfo(HashInfo {
139 output_count,
140 key: key.iter().map(|num| *num as u32).collect(),
141 }))
142 }
143 Distribution::SomeShard => None,
145 Distribution::Broadcast => None,
146 Distribution::UpstreamHashShard(key, table_id) => {
147 assert!(
148 !key.is_empty(),
149 "hash key should not be empty, use `Single` instead"
150 );
151
152 let vnode_mapping = worker_node_manager
153 .fragment_mapping(Self::get_fragment_id(catalog_reader, table_id)?)?;
154
155 let worker_slot_to_id_map: HashMap<WorkerSlotId, u32> = vnode_mapping
156 .iter_unique()
157 .enumerate()
158 .map(|(i, worker_slot_id)| (worker_slot_id, i as u32))
159 .collect();
160
161 Some(PbDistribution::ConsistentHashInfo(ConsistentHashInfo {
162 vmap: vnode_mapping
163 .iter()
164 .map(|id| worker_slot_to_id_map[&id])
165 .collect_vec(),
166 key: key.iter().map(|num| *num as u32).collect(),
167 }))
168 }
169 },
170 };
171 Ok(exchange_info)
172 }
173
174 pub fn satisfies(&self, required: &RequiredDist) -> bool {
176 match required {
177 RequiredDist::Any => true,
178 RequiredDist::AnyShard => {
179 matches!(
180 self,
181 Distribution::SomeShard
182 | Distribution::HashShard(_)
183 | Distribution::UpstreamHashShard(_, _)
184 | Distribution::Broadcast
185 )
186 }
187 RequiredDist::ShardByKey(required_key) => match self {
188 Distribution::HashShard(hash_key)
189 | Distribution::UpstreamHashShard(hash_key, _) => {
190 hash_key.iter().all(|idx| required_key.contains(*idx))
191 }
192 _ => false,
193 },
194 RequiredDist::PhysicalDist(other) => self == other,
195 }
196 }
197
198 pub fn dist_column_indices(&self) -> &[usize] {
202 match self {
203 Distribution::Single => &[],
204 Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => dists,
205 Distribution::SomeShard | Distribution::Broadcast => {
206 panic!("cannot obtain distribution columns for {self:?}")
207 }
208 }
209 }
210
211 #[inline(always)]
212 fn get_fragment_id(catalog_reader: &CatalogReader, table_id: &TableId) -> Result<FragmentId> {
213 catalog_reader
214 .read_guard()
215 .get_any_table_by_id(table_id)
216 .map(|table| table.fragment_id)
217 .map_err(Into::into)
218 }
219}
220
221impl fmt::Display for Distribution {
222 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223 f.write_str("[")?;
224 match self {
225 Self::Single => f.write_str("Single")?,
226 Self::SomeShard => f.write_str("SomeShard")?,
227 Self::Broadcast => f.write_str("Broadcast")?,
228 Self::HashShard(vec) | Self::UpstreamHashShard(vec, _) => {
229 for key in vec {
230 std::fmt::Debug::fmt(&key, f)?;
231 }
232 }
233 }
234 f.write_str("]")
235 }
236}
237
238pub struct DistributionDisplay<'a> {
239 pub distribution: &'a Distribution,
240 pub input_schema: &'a Schema,
241}
242
243impl DistributionDisplay<'_> {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 let that = self.distribution;
246 match that {
247 Distribution::Single => f.write_str("Single"),
248 Distribution::SomeShard => f.write_str("SomeShard"),
249 Distribution::Broadcast => f.write_str("Broadcast"),
250 Distribution::HashShard(vec) | Distribution::UpstreamHashShard(vec, _) => {
251 if let Distribution::HashShard(_) = that {
252 f.write_str("HashShard(")?;
253 } else {
254 f.write_str("UpstreamHashShard(")?;
255 }
256 for (pos, key) in vec.iter().copied().with_position() {
257 std::fmt::Debug::fmt(
258 &FieldDisplay(self.input_schema.fields.get(key).unwrap()),
259 f,
260 )?;
261 match pos {
262 itertools::Position::First | itertools::Position::Middle => {
263 f.write_str(", ")?;
264 }
265 _ => {}
266 }
267 }
268 f.write_str(")")
269 }
270 }
271 }
272}
273
274impl fmt::Debug for DistributionDisplay<'_> {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 self.fmt(f)
277 }
278}
279
280impl fmt::Display for DistributionDisplay<'_> {
281 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282 self.fmt(f)
283 }
284}
285
286impl RequiredDist {
287 pub fn single() -> Self {
288 Self::PhysicalDist(Distribution::Single)
289 }
290
291 pub fn shard_by_key(tot_col_num: usize, key: &[usize]) -> Self {
292 let mut cols = FixedBitSet::with_capacity(tot_col_num);
293 for i in key {
294 cols.insert(*i);
295 }
296 assert!(!cols.is_clear());
297 Self::ShardByKey(cols)
298 }
299
300 pub fn hash_shard(key: &[usize]) -> Self {
301 assert!(!key.is_empty());
302 Self::PhysicalDist(Distribution::HashShard(key.to_vec()))
303 }
304
305 pub fn batch_enforce_if_not_satisfies(
306 &self,
307 mut plan: BatchPlanRef,
308 required_order: &Order,
309 ) -> Result<BatchPlanRef> {
310 plan = required_order.enforce_if_not_satisfies(plan)?;
311 if !plan.distribution().satisfies(self) {
312 Ok(self.batch_enforce(plan, required_order))
313 } else {
314 Ok(plan)
315 }
316 }
317
318 pub fn streaming_enforce_if_not_satisfies(&self, plan: StreamPlanRef) -> Result<StreamPlanRef> {
319 if !plan.distribution().satisfies(self) {
320 Ok(self.stream_enforce(plan))
321 } else {
322 Ok(plan)
323 }
324 }
325
326 pub fn no_shuffle(plan: StreamPlanRef) -> StreamPlanRef {
327 StreamExchange::new_no_shuffle(plan).into()
328 }
329
330 pub fn satisfies(&self, required: &RequiredDist) -> bool {
332 match self {
333 RequiredDist::Any => matches!(required, RequiredDist::Any),
334 RequiredDist::AnyShard => {
335 matches!(required, RequiredDist::Any | RequiredDist::AnyShard)
336 }
337 RequiredDist::ShardByKey(key) => match required {
338 RequiredDist::Any | RequiredDist::AnyShard => true,
339 RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
340 _ => false,
341 },
342 RequiredDist::PhysicalDist(dist) => dist.satisfies(required),
343 }
344 }
345
346 pub fn batch_enforce(&self, plan: BatchPlanRef, required_order: &Order) -> BatchPlanRef {
347 let dist = self.to_dist();
348 BatchExchange::new(plan, required_order.clone(), dist).into()
349 }
350
351 pub fn stream_enforce(&self, plan: StreamPlanRef) -> StreamPlanRef {
352 let dist = self.to_dist();
353 StreamExchange::new(plan, dist).into()
354 }
355
356 fn to_dist(&self) -> Distribution {
357 match self {
358 RequiredDist::Any => unreachable!(),
361 RequiredDist::AnyShard => todo!(),
363 RequiredDist::ShardByKey(required_keys) => {
364 Distribution::HashShard(required_keys.ones().collect())
365 }
366 RequiredDist::PhysicalDist(dist) => dist.clone(),
367 }
368 }
369}
370
371impl StreamPlanRef {
372 pub fn enforce_concrete_distribution(self) -> Self {
375 match self.distribution() {
376 Distribution::SomeShard => {
377 RequiredDist::shard_by_key(self.schema().len(), self.expect_stream_key())
378 .stream_enforce(self)
379 }
380 _ => self,
381 }
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::{Distribution, RequiredDist};

    #[test]
    fn hash_shard_satisfy() {
        let dists = [
            Distribution::HashShard(vec![0, 1]),
            Distribution::HashShard(vec![1, 0]),
            Distribution::HashShard(vec![0]),
            Distribution::HashShard(vec![1]),
        ];

        // `PhysicalDist` demands exact equality (key order included), so each distribution
        // satisfies only itself.
        for (i, actual) in dists.iter().enumerate() {
            for (j, wanted) in dists.iter().enumerate() {
                let required = RequiredDist::PhysicalDist(wanted.clone());
                assert_eq!(
                    actual.satisfies(&required),
                    i == j,
                    "{actual:?} vs PhysicalDist({wanted:?})"
                );
            }
        }

        // `ShardByKey` is met when every hash column lies in the required set.
        // Each row: (requirement, expected outcome for each entry of `dists`).
        let by_key = [
            (RequiredDist::shard_by_key(2, &[0, 1]), [true, true, true, true]),
            (RequiredDist::shard_by_key(2, &[0]), [false, false, true, false]),
            (RequiredDist::shard_by_key(2, &[1]), [false, false, false, true]),
        ];
        for (required, expected) in &by_key {
            for (actual, &want) in dists.iter().zip(expected) {
                assert_eq!(actual.satisfies(required), want, "{actual:?} vs {required:?}");
            }
        }

        // A narrower key requirement satisfies a wider one but not the reverse, and
        // disjoint single-column requirements do not satisfy each other.
        let (r1, r3, r4) = (&by_key[0].0, &by_key[1].0, &by_key[2].0);
        assert!(r3.satisfies(r1));
        assert!(r4.satisfies(r1));
        assert!(!r1.satisfies(r3));
        assert!(!r1.satisfies(r4));
        assert!(!r3.satisfies(r4));
        assert!(!r4.satisfies(r3));
    }
}