1use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
16use std::ops::{AddAssign, Deref};
17use std::sync::Arc;
18
19use itertools::Itertools;
20use risingwave_common::bitmap::Bitmap;
21use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask, TableId};
22use risingwave_common::hash::{IsSingleton, VirtualNode, VnodeCount, VnodeCountCompat};
23use risingwave_common::id::JobId;
24use risingwave_common::system_param::AdaptiveParallelismStrategy;
25use risingwave_common::system_param::adaptive_parallelism_strategy::parse_strategy;
26use risingwave_common::util::stream_graph_visitor::{self, visit_stream_node_body};
27use risingwave_meta_model::{DispatcherType, SourceId, StreamingParallelism, WorkerId, fragment};
28use risingwave_pb::catalog::Table;
29use risingwave_pb::common::ActorInfo;
30use risingwave_pb::id::SubscriberId;
31use risingwave_pb::meta::table_fragments::fragment::{
32 FragmentDistributionType, PbFragmentDistributionType,
33};
34use risingwave_pb::meta::table_fragments::{PbActorStatus, PbFragment, State};
35use risingwave_pb::meta::table_parallelism::{
36 FixedParallelism, Parallelism, PbAdaptiveParallelism, PbCustomParallelism, PbFixedParallelism,
37 PbParallelism,
38};
39use risingwave_pb::meta::{PbTableFragments, PbTableParallelism};
40use risingwave_pb::plan_common::PbExprContext;
41use risingwave_pb::stream_plan::stream_node::NodeBody;
42use risingwave_pb::stream_plan::{
43 DispatchStrategy, Dispatcher, PbDispatchOutputMapping, PbDispatcher, PbStreamActor,
44 PbStreamContext, StreamNode,
45};
46use strum::Display;
47
48use super::{ActorId, FragmentId};
49
/// Parallelism assigned to a streaming job.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum TableParallelism {
    /// Parallelism is decided adaptively (see `AdaptiveParallelismStrategy`).
    Adaptive,
    /// Pinned to exactly this number of parallel units.
    Fixed(usize),
    /// Custom parallelism not managed by the adaptive scheduler.
    Custom,
}
66
67impl From<PbTableParallelism> for TableParallelism {
68 fn from(value: PbTableParallelism) -> Self {
69 use Parallelism::*;
70 match &value.parallelism {
71 Some(Fixed(FixedParallelism { parallelism: n })) => Self::Fixed(*n as usize),
72 Some(Adaptive(_)) | Some(Auto(_)) => Self::Adaptive,
73 Some(Custom(_)) => Self::Custom,
74 _ => unreachable!(),
75 }
76 }
77}
78
79impl From<TableParallelism> for PbTableParallelism {
80 fn from(value: TableParallelism) -> Self {
81 use TableParallelism::*;
82
83 let parallelism = match value {
84 Adaptive => PbParallelism::Adaptive(PbAdaptiveParallelism {}),
85 Fixed(n) => PbParallelism::Fixed(PbFixedParallelism {
86 parallelism: n as u32,
87 }),
88 Custom => PbParallelism::Custom(PbCustomParallelism {}),
89 };
90
91 Self {
92 parallelism: Some(parallelism),
93 }
94 }
95}
96
97impl From<StreamingParallelism> for TableParallelism {
98 fn from(value: StreamingParallelism) -> Self {
99 match value {
100 StreamingParallelism::Adaptive => TableParallelism::Adaptive,
101 StreamingParallelism::Fixed(n) => TableParallelism::Fixed(n),
102 StreamingParallelism::Custom => TableParallelism::Custom,
103 }
104 }
105}
106
107impl From<TableParallelism> for StreamingParallelism {
108 fn from(value: TableParallelism) -> Self {
109 match value {
110 TableParallelism::Adaptive => StreamingParallelism::Adaptive,
111 TableParallelism::Fixed(n) => StreamingParallelism::Fixed(n),
112 TableParallelism::Custom => StreamingParallelism::Custom,
113 }
114 }
115}
116
/// Upstream actor infos of one actor, grouped by upstream fragment id.
pub type ActorUpstreams = BTreeMap<FragmentId, HashMap<ActorId, ActorInfo>>;
/// An actor together with its outgoing dispatchers.
pub type StreamActorWithDispatchers = (StreamActor, Vec<PbDispatcher>);
/// An actor together with both its upstream infos and outgoing dispatchers.
pub type StreamActorWithUpDownstreams = (StreamActor, ActorUpstreams, Vec<PbDispatcher>);
/// Per-fragment, per-actor dispatcher lists.
pub type FragmentActorDispatchers = HashMap<FragmentId, HashMap<ActorId, Vec<PbDispatcher>>>;

/// Downstream relations of each fragment.
pub type FragmentDownstreamRelation = HashMap<FragmentId, Vec<DownstreamFragmentRelation>>;
/// Per-fragment map of `original upstream fragment id -> new upstream fragment id`.
pub type FragmentReplaceUpstream = HashMap<FragmentId, HashMap<FragmentId, FragmentId>>;
/// Per-fragment, per-downstream-fragment no-shuffle actor pairing
/// (`upstream actor id -> downstream actor id`).
/// NOTE(review): exact key/value orientation inferred from the nesting — confirm at call sites.
pub type ActorNewNoShuffle = HashMap<FragmentId, HashMap<FragmentId, HashMap<ActorId, ActorId>>>;
128
/// Describes how a fragment dispatches to one downstream fragment.
#[derive(Debug, Clone)]
pub struct DownstreamFragmentRelation {
    pub downstream_fragment_id: FragmentId,
    pub dispatcher_type: DispatcherType,
    /// Distribution key column indices used for dispatching.
    pub dist_key_indices: Vec<u32>,
    /// Mapping from upstream output columns to downstream inputs.
    pub output_mapping: PbDispatchOutputMapping,
}
136
impl From<(FragmentId, DispatchStrategy)> for DownstreamFragmentRelation {
    /// Builds a relation from a `(downstream fragment id, strategy)` pair.
    ///
    /// Panics if the strategy's `type` or `output_mapping` is unset; callers
    /// are expected to pass fully-populated strategies.
    fn from((fragment_id, dispatch): (FragmentId, DispatchStrategy)) -> Self {
        Self {
            downstream_fragment_id: fragment_id,
            dispatcher_type: dispatch.get_type().unwrap().into(),
            dist_key_indices: dispatch.dist_key_indices,
            output_mapping: dispatch.output_mapping.unwrap(),
        }
    }
}
147
/// A [`StreamJobFragments`] bundled with the downstream relations needed to
/// actually create the job.
#[derive(Debug, Clone)]
pub struct StreamJobFragmentsToCreate {
    pub inner: StreamJobFragments,
    pub downstreams: FragmentDownstreamRelation,
}
153
impl Deref for StreamJobFragmentsToCreate {
    type Target = StreamJobFragments;

    /// Delegates all read access to the wrapped [`StreamJobFragments`].
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
161
/// In-memory representation of a stream actor (dispatchers are kept separately;
/// see [`StreamActorWithDispatchers`]).
#[derive(Clone, Debug)]
pub struct StreamActor {
    pub actor_id: ActorId,
    pub fragment_id: FragmentId,
    /// Vnode ownership of this actor; `None` for singleton distributions
    /// — NOTE(review): inferred from usage, confirm.
    pub vnode_bitmap: Option<Bitmap>,
    /// SQL definition of the owning materialized view, for display purposes.
    pub mview_definition: String,
    pub expr_context: Option<PbExprContext>,
    /// Shared config-override string; `Arc<str>` makes clones cheap.
    pub config_override: Arc<str>,
}
172
173impl StreamActor {
174 fn to_protobuf(&self, dispatchers: impl Iterator<Item = Dispatcher>) -> PbStreamActor {
175 PbStreamActor {
176 actor_id: self.actor_id,
177 fragment_id: self.fragment_id,
178 dispatcher: dispatchers.collect(),
179 vnode_bitmap: self
180 .vnode_bitmap
181 .as_ref()
182 .map(|bitmap| bitmap.to_protobuf()),
183 mview_definition: self.mview_definition.clone(),
184 expr_context: self.expr_context.clone(),
185 config_override: self.config_override.to_string(),
186 }
187 }
188}
189
/// In-memory representation of a fragment (actors are kept separately).
#[derive(Clone, Debug, Default)]
pub struct Fragment {
    pub fragment_id: FragmentId,
    pub fragment_type_mask: FragmentTypeMask,
    pub distribution_type: PbFragmentDistributionType,
    /// State tables owned by this fragment.
    pub state_table_ids: Vec<TableId>,
    /// Vnode count in protobuf-compatible form; resolved via
    /// [`VnodeCountCompat::vnode_count_inner`].
    pub maybe_vnode_count: Option<u32>,
    /// Root of the fragment's stream node tree.
    pub nodes: StreamNode,
}
199
200impl Fragment {
201 pub fn to_protobuf(
202 &self,
203 actors: &[StreamActor],
204 upstream_fragments: impl Iterator<Item = FragmentId>,
205 dispatchers: Option<&HashMap<ActorId, Vec<Dispatcher>>>,
206 ) -> PbFragment {
207 PbFragment {
208 fragment_id: self.fragment_id,
209 fragment_type_mask: self.fragment_type_mask.into(),
210 distribution_type: self.distribution_type as _,
211 actors: actors
212 .iter()
213 .map(|actor| {
214 actor.to_protobuf(
215 dispatchers
216 .and_then(|dispatchers| dispatchers.get(&actor.actor_id))
217 .into_iter()
218 .flatten()
219 .cloned(),
220 )
221 })
222 .collect(),
223 state_table_ids: self.state_table_ids.clone(),
224 upstream_fragment_ids: upstream_fragments.collect(),
225 maybe_vnode_count: self.maybe_vnode_count,
226 nodes: Some(self.nodes.clone()),
227 }
228 }
229}
230
impl VnodeCountCompat for Fragment {
    fn vnode_count_inner(&self) -> VnodeCount {
        // When the count is absent, fall back on singleton detection.
        VnodeCount::from_protobuf(self.maybe_vnode_count, || self.is_singleton())
    }
}
236
237impl IsSingleton for Fragment {
238 fn is_singleton(&self) -> bool {
239 matches!(self.distribution_type, FragmentDistributionType::Single)
240 }
241}
242
impl From<fragment::Model> for Fragment {
    /// Converts a persisted fragment row into the in-memory representation.
    fn from(model: fragment::Model) -> Self {
        Self {
            fragment_id: model.fragment_id,
            fragment_type_mask: FragmentTypeMask::from(model.fragment_type_mask),
            distribution_type: model.distribution_type.into(),
            state_table_ids: model.state_table_ids.into_inner(),
            // Persisted rows always carry a concrete vnode count.
            maybe_vnode_count: VnodeCount::set(model.vnode_count).to_protobuf(),
            nodes: model.stream_node.to_protobuf(),
        }
    }
}
255
/// The fragments of a single streaming job, together with its scheduling
/// context and parallelism settings.
#[derive(Debug, Clone)]
pub struct StreamJobFragments {
    /// Id of the streaming job.
    pub stream_job_id: JobId,

    /// Creation state of the job (`Initial` when built via [`Self::new`]).
    pub state: State,

    /// All fragments of the job, keyed by fragment id.
    pub fragments: BTreeMap<FragmentId, Fragment>,

    /// Streaming context: timezone, config override, parallelism strategy.
    pub ctx: StreamContext,

    /// Parallelism assigned to this job.
    pub assigned_parallelism: TableParallelism,

    /// Upper bound on the job's parallelism.
    pub max_parallelism: usize,
}
290
/// Per-job streaming context, round-trippable to [`PbStreamContext`].
#[derive(Debug, Clone, Default)]
pub struct StreamContext {
    /// Session timezone; `None` is serialized as an empty string.
    pub timezone: Option<String>,

    /// Shared config-override string; empty when unset.
    pub config_override: Arc<str>,

    /// Adaptive parallelism strategy; `None` is serialized as an empty string.
    pub adaptive_parallelism_strategy: Option<AdaptiveParallelismStrategy>,
}
302
303impl StreamContext {
304 pub fn to_protobuf(&self) -> PbStreamContext {
305 PbStreamContext {
306 timezone: self.timezone.clone().unwrap_or("".into()),
307 config_override: self.config_override.to_string(),
308 adaptive_parallelism_strategy: self
309 .adaptive_parallelism_strategy
310 .as_ref()
311 .map(ToString::to_string)
312 .unwrap_or_default(),
313 backfill_adaptive_parallelism_strategy: String::new(),
314 }
315 }
316
317 pub fn to_expr_context(&self) -> PbExprContext {
318 PbExprContext {
319 time_zone: self.timezone.clone().unwrap_or("Empty Time Zone".into()),
321 strict_mode: false,
322 }
323 }
324
325 pub fn from_protobuf(prost: &PbStreamContext) -> Self {
326 Self {
327 timezone: if prost.get_timezone().is_empty() {
328 None
329 } else {
330 Some(prost.get_timezone().clone())
331 },
332 config_override: prost.get_config_override().as_str().into(),
333 adaptive_parallelism_strategy: if prost.get_adaptive_parallelism_strategy().is_empty() {
334 None
335 } else {
336 Some(
337 parse_strategy(prost.get_adaptive_parallelism_strategy())
338 .expect("adaptive parallelism strategy should be validated in frontend"),
339 )
340 },
341 }
342 }
343}
344
/// Extension on the persisted `streaming_job` model to reconstruct its
/// [`StreamContext`].
#[easy_ext::ext(StreamingJobModelContextExt)]
impl risingwave_meta_model::streaming_job::Model {
    /// Builds a [`StreamContext`] from the persisted columns.
    ///
    /// Panics if a persisted strategy string fails to parse — strategies are
    /// validated before being written.
    pub fn stream_context(&self) -> StreamContext {
        StreamContext {
            timezone: self.timezone.clone(),
            // A missing override is treated as an empty string.
            config_override: self.config_override.clone().unwrap_or_default().into(),
            adaptive_parallelism_strategy: self.adaptive_parallelism_strategy.as_deref().map(|s| {
                parse_strategy(s).expect("strategy should be validated before persisting")
            }),
        }
    }
}
357
358impl StreamJobFragments {
359 pub fn to_protobuf(
360 &self,
361 fragment_actors: &HashMap<FragmentId, Vec<StreamActor>>,
362 fragment_upstreams: &HashMap<FragmentId, HashSet<FragmentId>>,
363 fragment_dispatchers: &FragmentActorDispatchers,
364 actor_status: HashMap<ActorId, PbActorStatus>,
365 ) -> PbTableFragments {
366 PbTableFragments {
367 table_id: self.stream_job_id,
368 state: self.state as _,
369 fragments: self
370 .fragments
371 .iter()
372 .map(|(id, fragment)| {
373 let actors = fragment_actors.get(id).map(|a| a.as_slice()).unwrap_or(&[]);
374 (
375 *id,
376 fragment.to_protobuf(
377 actors,
378 fragment_upstreams.get(id).into_iter().flatten().cloned(),
379 fragment_dispatchers.get(id),
380 ),
381 )
382 })
383 .collect(),
384 actor_status,
385 ctx: Some(self.ctx.to_protobuf()),
386 parallelism: Some(self.assigned_parallelism.into()),
387 node_label: "".to_owned(),
388 backfill_done: true,
389 max_parallelism: Some(self.max_parallelism as _),
390 }
391 }
392}
393
/// Actors to create, grouped by worker and then by fragment. Each fragment
/// entry carries its stream node tree, the actors (with up/downstreams) to
/// build, and the subscribers of the fragment.
pub type StreamJobActorsToCreate = HashMap<
    WorkerId,
    HashMap<
        FragmentId,
        (
            StreamNode,
            Vec<StreamActorWithUpDownstreams>,
            HashSet<SubscriberId>,
        ),
    >,
>;
405
406impl StreamJobFragments {
407 pub fn for_test(job_id: JobId, fragments: BTreeMap<FragmentId, Fragment>) -> Self {
409 Self::new(
410 job_id,
411 fragments,
412 StreamContext::default(),
413 TableParallelism::Adaptive,
414 VirtualNode::COUNT_FOR_TEST,
415 )
416 }
417
418 pub fn new(
420 stream_job_id: JobId,
421 fragments: BTreeMap<FragmentId, Fragment>,
422 ctx: StreamContext,
423 table_parallelism: TableParallelism,
424 max_parallelism: usize,
425 ) -> Self {
426 Self {
427 stream_job_id,
428 state: State::Initial,
429 fragments,
430 ctx,
431 assigned_parallelism: table_parallelism,
432 max_parallelism,
433 }
434 }
435
436 pub fn fragment_ids(&self) -> impl Iterator<Item = FragmentId> + '_ {
437 self.fragments.keys().cloned()
438 }
439
440 pub fn fragments(&self) -> impl Iterator<Item = &Fragment> {
441 self.fragments.values()
442 }
443
444 pub fn stream_job_id(&self) -> JobId {
446 self.stream_job_id
447 }
448
449 pub fn timezone(&self) -> Option<String> {
451 self.ctx.timezone.clone()
452 }
453
454 pub fn is_created(&self) -> bool {
456 self.state == State::Created
457 }
458
459 #[cfg(test)]
461 pub fn mview_fragment_ids(&self) -> Vec<FragmentId> {
462 self.fragments
463 .values()
464 .filter(move |fragment| {
465 fragment
466 .fragment_type_mask
467 .contains(FragmentTypeFlag::Mview)
468 })
469 .map(|fragment| fragment.fragment_id)
470 .collect()
471 }
472
473 pub fn tracking_progress_actor_ids_impl(
475 fragments: impl IntoIterator<Item = (FragmentTypeMask, impl Iterator<Item = ActorId>)>,
476 ) -> Vec<(ActorId, BackfillUpstreamType)> {
477 let mut actor_ids = vec![];
478 for (fragment_type_mask, actors) in fragments {
479 if fragment_type_mask.contains(FragmentTypeFlag::CdcFilter) {
480 return vec![];
483 }
484 if fragment_type_mask.contains_any([
485 FragmentTypeFlag::Values,
486 FragmentTypeFlag::StreamScan,
487 FragmentTypeFlag::SourceScan,
488 FragmentTypeFlag::LocalityProvider,
489 ]) {
490 actor_ids.extend(actors.map(|actor_id| {
491 (
492 actor_id,
493 BackfillUpstreamType::from_fragment_type_mask(fragment_type_mask),
494 )
495 }));
496 }
497 }
498 actor_ids
499 }
500
501 pub fn root_fragment(&self) -> Option<Fragment> {
502 self.mview_fragment()
503 .or_else(|| self.sink_fragment())
504 .or_else(|| self.source_fragment())
505 }
506
507 pub fn mview_fragment(&self) -> Option<Fragment> {
509 self.fragments
510 .values()
511 .find(|fragment| {
512 fragment
513 .fragment_type_mask
514 .contains(FragmentTypeFlag::Mview)
515 })
516 .cloned()
517 }
518
519 pub fn source_fragment(&self) -> Option<Fragment> {
520 self.fragments
521 .values()
522 .find(|fragment| {
523 fragment
524 .fragment_type_mask
525 .contains(FragmentTypeFlag::Source)
526 })
527 .cloned()
528 }
529
530 pub fn sink_fragment(&self) -> Option<Fragment> {
531 self.fragments
532 .values()
533 .find(|fragment| fragment.fragment_type_mask.contains(FragmentTypeFlag::Sink))
534 .cloned()
535 }
536
537 pub fn stream_source_fragments(&self) -> HashMap<SourceId, BTreeSet<FragmentId>> {
540 let mut source_fragments = HashMap::new();
541
542 for fragment in self.fragments() {
543 {
544 if let Some(source_id) = fragment.nodes.find_stream_source() {
545 source_fragments
546 .entry(source_id)
547 .or_insert(BTreeSet::new())
548 .insert(fragment.fragment_id as FragmentId);
549 }
550 }
551 }
552 source_fragments
553 }
554
555 pub fn source_backfill_fragments(
556 &self,
557 ) -> HashMap<SourceId, BTreeSet<(FragmentId, FragmentId)>> {
558 Self::source_backfill_fragments_impl(
559 self.fragments
560 .iter()
561 .map(|(fragment_id, fragment)| (*fragment_id, &fragment.nodes)),
562 )
563 }
564
565 pub fn source_backfill_fragments_impl(
570 fragments: impl Iterator<Item = (FragmentId, &StreamNode)>,
571 ) -> HashMap<SourceId, BTreeSet<(FragmentId, FragmentId)>> {
572 let mut source_backfill_fragments = HashMap::new();
573
574 for (fragment_id, fragment_node) in fragments {
575 {
576 if let Some((source_id, upstream_source_fragment_id)) =
577 fragment_node.find_source_backfill()
578 {
579 source_backfill_fragments
580 .entry(source_id)
581 .or_insert(BTreeSet::new())
582 .insert((fragment_id, upstream_source_fragment_id));
583 }
584 }
585 }
586 source_backfill_fragments
587 }
588
589 pub fn union_fragment_for_table(&mut self) -> &mut Fragment {
592 let mut union_fragment_id = None;
593 for (fragment_id, fragment) in &self.fragments {
594 {
595 {
596 visit_stream_node_body(&fragment.nodes, |body| {
597 if let NodeBody::Union(_) = body {
598 if let Some(union_fragment_id) = union_fragment_id.as_mut() {
599 assert_eq!(*union_fragment_id, *fragment_id);
601 } else {
602 union_fragment_id = Some(*fragment_id);
603 }
604 }
605 })
606 }
607 }
608 }
609
610 let union_fragment_id =
611 union_fragment_id.expect("fragment of placeholder merger not found");
612
613 (self
614 .fragments
615 .get_mut(&union_fragment_id)
616 .unwrap_or_else(|| panic!("fragment {} not found", union_fragment_id))) as _
617 }
618
619 fn resolve_dependent_table(stream_node: &StreamNode, table_ids: &mut HashMap<TableId, usize>) {
621 let table_id = match stream_node.node_body.as_ref() {
622 Some(NodeBody::StreamScan(stream_scan)) => Some(stream_scan.table_id),
623 Some(NodeBody::StreamCdcScan(stream_scan)) => Some(stream_scan.table_id),
624 Some(NodeBody::LocalityProvider(state)) => {
625 Some(state.state_table.as_ref().expect("must have state").id)
626 }
627 _ => None,
628 };
629 if let Some(table_id) = table_id {
630 table_ids.entry(table_id).or_default().add_assign(1);
631 }
632
633 for child in &stream_node.input {
634 Self::resolve_dependent_table(child, table_ids);
635 }
636 }
637
638 pub fn upstream_table_counts(&self) -> HashMap<TableId, usize> {
639 Self::upstream_table_counts_impl(self.fragments.values().map(|fragment| &fragment.nodes))
640 }
641
642 pub fn upstream_table_counts_impl(
644 fragment_nodes: impl Iterator<Item = &StreamNode>,
645 ) -> HashMap<TableId, usize> {
646 let mut table_ids = HashMap::new();
647 fragment_nodes.for_each(|node| {
648 Self::resolve_dependent_table(node, &mut table_ids);
649 });
650
651 table_ids
652 }
653
654 pub fn mv_table_id(&self) -> Option<TableId> {
655 self.fragments
656 .values()
657 .flat_map(|f| f.state_table_ids.iter().copied())
658 .find(|table_id| self.stream_job_id.is_mv_table_id(*table_id))
659 }
660
661 pub fn collect_tables(fragments: impl Iterator<Item = &Fragment>) -> BTreeMap<TableId, Table> {
662 let mut tables = BTreeMap::new();
663 for fragment in fragments {
664 stream_graph_visitor::visit_stream_node_tables_inner(
665 &mut fragment.nodes.clone(),
666 false,
667 true,
668 |table, _| {
669 let table_id = table.id;
670 tables
671 .try_insert(table_id, table.clone())
672 .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
673 },
674 );
675 }
676 tables
677 }
678
679 pub fn internal_table_ids(&self) -> Vec<TableId> {
681 self.fragments
682 .values()
683 .flat_map(|f| f.state_table_ids.iter().copied())
684 .filter(|&t| !self.stream_job_id.is_mv_table_id(t))
685 .collect_vec()
686 }
687
688 pub fn all_table_ids(&self) -> impl Iterator<Item = TableId> + '_ {
690 self.fragments
691 .values()
692 .flat_map(|f| f.state_table_ids.clone())
693 }
694}
695
/// Kind of upstream a backfill fragment reads from, derived from its fragment
/// type mask via [`BackfillUpstreamType::from_fragment_type_mask`].
#[derive(Debug, Display, Clone, Copy, PartialEq, Eq)]
pub enum BackfillUpstreamType {
    MView,
    Values,
    Source,
    LocalityProvider,
}
703
704impl BackfillUpstreamType {
705 pub fn from_fragment_type_mask(mask: FragmentTypeMask) -> Self {
706 let is_mview = mask.contains(FragmentTypeFlag::StreamScan);
707 let is_values = mask.contains(FragmentTypeFlag::Values);
708 let is_source = mask.contains(FragmentTypeFlag::SourceScan);
709 let is_locality_provider = mask.contains(FragmentTypeFlag::LocalityProvider);
710
711 debug_assert!(
714 is_mview as u8 + is_values as u8 + is_source as u8 + is_locality_provider as u8 == 1,
715 "a backfill fragment should either be mview, value, source, or locality provider, found {:?}",
716 mask
717 );
718
719 if is_mview {
720 BackfillUpstreamType::MView
721 } else if is_values {
722 BackfillUpstreamType::Values
723 } else if is_source {
724 BackfillUpstreamType::Source
725 } else if is_locality_provider {
726 BackfillUpstreamType::LocalityProvider
727 } else {
728 unreachable!("invalid fragment type mask: {:?}", mask);
729 }
730 }
731}