risingwave_frontend/optimizer/property/
distribution.rs1use std::collections::HashMap;
46use std::fmt;
47use std::fmt::Debug;
48
49use fixedbitset::FixedBitSet;
50use generic::PhysicalPlanRef;
51use itertools::Itertools;
52use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
53use risingwave_common::catalog::{FieldDisplay, Schema, TableId};
54use risingwave_common::hash::WorkerSlotId;
55use risingwave_pb::batch_plan::ExchangeInfo;
56use risingwave_pb::batch_plan::exchange_info::{
57 ConsistentHashInfo, Distribution as PbDistribution, DistributionMode, HashInfo,
58};
59
60use super::super::plan_node::*;
61use crate::catalog::FragmentId;
62use crate::catalog::catalog_service::CatalogReader;
63use crate::error::Result;
64use crate::optimizer::property::Order;
65
66#[derive(Debug, Clone, PartialEq, Eq, Hash)]
68pub enum Distribution {
69 Single,
75 SomeShard,
78 HashShard(Vec<usize>),
82 UpstreamHashShard(Vec<usize>, TableId),
96 Broadcast,
98}
99
100#[derive(Debug, Clone, PartialEq)]
102pub enum RequiredDist {
103 Any,
105 AnyShard,
107 ShardByKey(FixedBitSet),
111 PhysicalDist(Distribution),
113}
114
115impl Distribution {
116 pub fn to_prost(
117 &self,
118 output_count: u32,
119 catalog_reader: &CatalogReader,
120 worker_node_manager: &WorkerNodeSelector,
121 ) -> Result<ExchangeInfo> {
122 let exchange_info = ExchangeInfo {
123 mode: match self {
124 Distribution::Single => DistributionMode::Single,
125 Distribution::HashShard(_) => DistributionMode::Hash,
126 Distribution::SomeShard => DistributionMode::Single,
128 Distribution::Broadcast => DistributionMode::Broadcast,
129 Distribution::UpstreamHashShard(_, _) => DistributionMode::ConsistentHash,
130 } as i32,
131 distribution: match self {
132 Distribution::Single => None,
133 Distribution::HashShard(key) => {
134 assert!(
135 !key.is_empty(),
136 "hash key should not be empty, use `Single` instead"
137 );
138 Some(PbDistribution::HashInfo(HashInfo {
139 output_count,
140 key: key.iter().map(|num| *num as u32).collect(),
141 }))
142 }
143 Distribution::SomeShard => None,
145 Distribution::Broadcast => None,
146 Distribution::UpstreamHashShard(key, table_id) => {
147 assert!(
148 !key.is_empty(),
149 "hash key should not be empty, use `Single` instead"
150 );
151
152 let vnode_mapping = worker_node_manager
153 .fragment_mapping(Self::get_fragment_id(catalog_reader, table_id)?)?;
154
155 let worker_slot_to_id_map: HashMap<WorkerSlotId, u32> = vnode_mapping
156 .iter_unique()
157 .enumerate()
158 .map(|(i, worker_slot_id)| (worker_slot_id, i as u32))
159 .collect();
160
161 Some(PbDistribution::ConsistentHashInfo(ConsistentHashInfo {
162 vmap: vnode_mapping
163 .iter()
164 .map(|id| worker_slot_to_id_map[&id])
165 .collect_vec(),
166 key: key.iter().map(|num| *num as u32).collect(),
167 }))
168 }
169 },
170 };
171 Ok(exchange_info)
172 }
173
174 pub fn satisfies(&self, required: &RequiredDist) -> bool {
176 match required {
177 RequiredDist::Any => true,
178 RequiredDist::AnyShard => {
179 matches!(
180 self,
181 Distribution::SomeShard
182 | Distribution::HashShard(_)
183 | Distribution::UpstreamHashShard(_, _)
184 | Distribution::Broadcast
185 )
186 }
187 RequiredDist::ShardByKey(required_key) => match self {
188 Distribution::HashShard(hash_key)
189 | Distribution::UpstreamHashShard(hash_key, _) => {
190 hash_key.iter().all(|idx| required_key.contains(*idx))
191 }
192 _ => false,
193 },
194 RequiredDist::PhysicalDist(other) => self == other,
195 }
196 }
197
198 pub fn dist_column_indices(&self) -> &[usize] {
201 match self {
202 Distribution::Single | Distribution::SomeShard | Distribution::Broadcast => {
203 Default::default()
204 }
205 Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => dists,
206 }
207 }
208
209 #[inline(always)]
210 fn get_fragment_id(catalog_reader: &CatalogReader, table_id: &TableId) -> Result<FragmentId> {
211 catalog_reader
212 .read_guard()
213 .get_any_table_by_id(table_id)
214 .map(|table| table.fragment_id)
215 .map_err(Into::into)
216 }
217}
218
219impl fmt::Display for Distribution {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 f.write_str("[")?;
222 match self {
223 Self::Single => f.write_str("Single")?,
224 Self::SomeShard => f.write_str("SomeShard")?,
225 Self::Broadcast => f.write_str("Broadcast")?,
226 Self::HashShard(vec) | Self::UpstreamHashShard(vec, _) => {
227 for key in vec {
228 std::fmt::Debug::fmt(&key, f)?;
229 }
230 }
231 }
232 f.write_str("]")
233 }
234}
235
236pub struct DistributionDisplay<'a> {
237 pub distribution: &'a Distribution,
238 pub input_schema: &'a Schema,
239}
240
241impl DistributionDisplay<'_> {
242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243 let that = self.distribution;
244 match that {
245 Distribution::Single => f.write_str("Single"),
246 Distribution::SomeShard => f.write_str("SomeShard"),
247 Distribution::Broadcast => f.write_str("Broadcast"),
248 Distribution::HashShard(vec) | Distribution::UpstreamHashShard(vec, _) => {
249 if let Distribution::HashShard(_) = that {
250 f.write_str("HashShard(")?;
251 } else {
252 f.write_str("UpstreamHashShard(")?;
253 }
254 for (pos, key) in vec.iter().copied().with_position() {
255 std::fmt::Debug::fmt(
256 &FieldDisplay(self.input_schema.fields.get(key).unwrap()),
257 f,
258 )?;
259 match pos {
260 itertools::Position::First | itertools::Position::Middle => {
261 f.write_str(", ")?;
262 }
263 _ => {}
264 }
265 }
266 f.write_str(")")
267 }
268 }
269 }
270}
271
272impl fmt::Debug for DistributionDisplay<'_> {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 self.fmt(f)
275 }
276}
277
278impl fmt::Display for DistributionDisplay<'_> {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 self.fmt(f)
281 }
282}
283
284impl RequiredDist {
285 pub fn single() -> Self {
286 Self::PhysicalDist(Distribution::Single)
287 }
288
289 pub fn shard_by_key(tot_col_num: usize, key: &[usize]) -> Self {
290 let mut cols = FixedBitSet::with_capacity(tot_col_num);
291 for i in key {
292 cols.insert(*i);
293 }
294 assert!(!cols.is_clear());
295 Self::ShardByKey(cols)
296 }
297
298 pub fn hash_shard(key: &[usize]) -> Self {
299 assert!(!key.is_empty());
300 Self::PhysicalDist(Distribution::HashShard(key.to_vec()))
301 }
302
303 pub fn enforce_if_not_satisfies(
304 &self,
305 mut plan: PlanRef,
306 required_order: &Order,
307 ) -> Result<PlanRef> {
308 if let Convention::Batch = plan.convention() {
309 plan = required_order.enforce_if_not_satisfies(plan)?;
310 }
311 if !plan.distribution().satisfies(self) {
312 Ok(self.enforce(plan, required_order))
313 } else {
314 Ok(plan)
315 }
316 }
317
318 pub fn no_shuffle(plan: PlanRef) -> PlanRef {
319 match plan.convention() {
320 Convention::Stream => StreamExchange::new_no_shuffle(plan).into(),
321 Convention::Logical | Convention::Batch => unreachable!(),
322 }
323 }
324
325 pub fn satisfies(&self, required: &RequiredDist) -> bool {
327 match self {
328 RequiredDist::Any => matches!(required, RequiredDist::Any),
329 RequiredDist::AnyShard => {
330 matches!(required, RequiredDist::Any | RequiredDist::AnyShard)
331 }
332 RequiredDist::ShardByKey(key) => match required {
333 RequiredDist::Any | RequiredDist::AnyShard => true,
334 RequiredDist::ShardByKey(required_key) => key.is_subset(required_key),
335 _ => false,
336 },
337 RequiredDist::PhysicalDist(dist) => dist.satisfies(required),
338 }
339 }
340
341 pub fn enforce(&self, plan: PlanRef, required_order: &Order) -> PlanRef {
342 let dist = self.to_dist();
343 match plan.convention() {
344 Convention::Batch => BatchExchange::new(plan, required_order.clone(), dist).into(),
345 Convention::Stream => StreamExchange::new(plan, dist).into(),
346 _ => unreachable!(),
347 }
348 }
349
350 fn to_dist(&self) -> Distribution {
351 match self {
352 RequiredDist::Any => unreachable!(),
355 RequiredDist::AnyShard => todo!(),
357 RequiredDist::ShardByKey(required_keys) => {
358 Distribution::HashShard(required_keys.ones().collect())
359 }
360 RequiredDist::PhysicalDist(dist) => dist.clone(),
361 }
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::{Distribution, RequiredDist};

    /// Exercises `satisfies` for hash-sharded distributions against physical
    /// and shard-by-key requirements, and between shard-by-key requirements.
    #[test]
    fn hash_shard_satisfy() {
        let d1 = Distribution::HashShard(vec![0, 1]);
        let d2 = Distribution::HashShard(vec![1, 0]);
        let d3 = Distribution::HashShard(vec![0]);
        let d4 = Distribution::HashShard(vec![1]);
        let all_dists = [&d1, &d2, &d3, &d4];

        let r1 = RequiredDist::shard_by_key(2, &[0, 1]);
        let r3 = RequiredDist::shard_by_key(2, &[0]);
        let r4 = RequiredDist::shard_by_key(2, &[1]);

        let physical = |dist: &Distribution| RequiredDist::PhysicalDist(dist.clone());

        // A physical requirement is met by the identical distribution ...
        for dist in all_dists {
            assert!(dist.satisfies(&physical(dist)));
        }
        // ... and by nothing else, not even the same columns in another order.
        let physical_mismatches = [
            (&d2, &d1),
            (&d3, &d1),
            (&d4, &d1),
            (&d1, &d3),
            (&d2, &d3),
            (&d1, &d4),
            (&d2, &d4),
        ];
        for (actual, wanted) in physical_mismatches {
            assert!(!actual.satisfies(&physical(wanted)));
        }

        // Shard-by-key requirements: expected results for d1..d4, in that order.
        let shard_by_key_cases = [
            (&r1, [true, true, true, true]),
            (&r3, [false, false, true, false]),
            (&r4, [false, false, false, true]),
        ];
        for (requirement, expected) in shard_by_key_cases {
            for (dist, want) in all_dists.into_iter().zip(expected) {
                assert_eq!(dist.satisfies(requirement), want);
            }
        }

        // Between requirements: a narrower key set implies a wider one, not vice versa.
        assert!(r3.satisfies(&r1));
        assert!(r4.satisfies(&r1));
        for (narrower, wider) in [(&r1, &r3), (&r1, &r4), (&r3, &r4), (&r4, &r3)] {
            assert!(!narrower.satisfies(wider));
        }
    }
}