use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::ops::{AddAssign, Deref};
use std::sync::Arc;

use itertools::Itertools;
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask, TableId};
use risingwave_common::hash::{IsSingleton, VirtualNode, VnodeCount, VnodeCountCompat};
use risingwave_common::id::JobId;
use risingwave_common::util::stream_graph_visitor::{self, visit_stream_node_body};
use risingwave_meta_model::{DispatcherType, SourceId, StreamingParallelism, WorkerId, fragment};
use risingwave_pb::catalog::Table;
use risingwave_pb::common::ActorInfo;
use risingwave_pb::id::SubscriberId;
use risingwave_pb::meta::table_fragments::fragment::{
    FragmentDistributionType, PbFragmentDistributionType,
};
use risingwave_pb::meta::table_fragments::{PbActorStatus, PbFragment, State};
use risingwave_pb::meta::table_parallelism::{
    FixedParallelism, Parallelism, PbAdaptiveParallelism, PbCustomParallelism, PbFixedParallelism,
    PbParallelism,
};
use risingwave_pb::meta::{PbTableFragments, PbTableParallelism};
use risingwave_pb::plan_common::PbExprContext;
use risingwave_pb::stream_plan::stream_node::NodeBody;
use risingwave_pb::stream_plan::{
    DispatchStrategy, Dispatcher, PbDispatchOutputMapping, PbDispatcher, PbStreamActor,
    PbStreamContext, StreamNode,
};
use strum::Display;

use super::{ActorId, FragmentId};
/// The parallelism of a streaming job.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum TableParallelism {
    /// Parallelism is decided adaptively by the scheduler based on the
    /// available resources.
    Adaptive,
    /// Parallelism is pinned to the given number of actors per fragment.
    Fixed(usize),
    /// Parallelism is managed per fragment individually rather than uniformly,
    /// e.g. after manual adjustment or recovery from a non-uniform state.
    Custom,
}

impl From<PbTableParallelism> for TableParallelism {
    fn from(value: PbTableParallelism) -> Self {
        use Parallelism::*;
        match &value.parallelism {
            Some(Fixed(FixedParallelism { parallelism: n })) => Self::Fixed(*n as usize),
            Some(Adaptive(_)) | Some(Auto(_)) => Self::Adaptive,
            Some(Custom(_)) => Self::Custom,
            _ => unreachable!(),
        }
    }
}

impl From<TableParallelism> for PbTableParallelism {
    fn from(value: TableParallelism) -> Self {
        use TableParallelism::*;

        let parallelism = match value {
            Adaptive => PbParallelism::Adaptive(PbAdaptiveParallelism {}),
            Fixed(n) => PbParallelism::Fixed(PbFixedParallelism {
                parallelism: n as u32,
            }),
            Custom => PbParallelism::Custom(PbCustomParallelism {}),
        };

        Self {
            parallelism: Some(parallelism),
        }
    }
}

impl From<StreamingParallelism> for TableParallelism {
    fn from(value: StreamingParallelism) -> Self {
        match value {
            StreamingParallelism::Adaptive => TableParallelism::Adaptive,
            StreamingParallelism::Fixed(n) => TableParallelism::Fixed(n),
            StreamingParallelism::Custom => TableParallelism::Custom,
        }
    }
}

impl From<TableParallelism> for StreamingParallelism {
    fn from(value: TableParallelism) -> Self {
        match value {
            TableParallelism::Adaptive => StreamingParallelism::Adaptive,
            TableParallelism::Fixed(n) => StreamingParallelism::Fixed(n),
            TableParallelism::Custom => StreamingParallelism::Custom,
        }
    }
}

pub type ActorUpstreams = BTreeMap<FragmentId, HashMap<ActorId, ActorInfo>>;
pub type StreamActorWithDispatchers = (StreamActor, Vec<PbDispatcher>);
pub type StreamActorWithUpDownstreams = (StreamActor, ActorUpstreams, Vec<PbDispatcher>);
pub type FragmentActorDispatchers = HashMap<FragmentId, HashMap<ActorId, Vec<PbDispatcher>>>;

pub type FragmentDownstreamRelation = HashMap<FragmentId, Vec<DownstreamFragmentRelation>>;
/// Map of `fragment_id` -> `original_upstream_fragment_id` -> `new_upstream_fragment_id`,
/// used when replacing the upstream of a fragment.
pub type FragmentReplaceUpstream = HashMap<FragmentId, HashMap<FragmentId, FragmentId>>;
/// Newly created no-shuffle dispatches:
/// `fragment_id` -> `downstream_fragment_id` -> (`actor_id` -> `downstream_actor_id`).
pub type ActorNewNoShuffle = HashMap<FragmentId, HashMap<FragmentId, HashMap<ActorId, ActorId>>>;

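/// A downstream edge of a fragment, carrying the dispatch information needed to
/// connect it to one downstream fragment.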
#[derive(Debug, Clone)]
pub struct DownstreamFragmentRelation {
    pub downstream_fragment_id: FragmentId,
    pub dispatcher_type: DispatcherType,
    pub dist_key_indices: Vec<u32>,
    pub output_mapping: PbDispatchOutputMapping,
}

impl From<(FragmentId, DispatchStrategy)> for DownstreamFragmentRelation {
    fn from((fragment_id, dispatch): (FragmentId, DispatchStrategy)) -> Self {
        Self {
            downstream_fragment_id: fragment_id,
            dispatcher_type: dispatch.get_type().unwrap().into(),
            dist_key_indices: dispatch.dist_key_indices,
            output_mapping: dispatch.output_mapping.unwrap(),
        }
    }
}

#[derive(Debug, Clone)]
pub struct StreamJobFragmentsToCreate {
    pub inner: StreamJobFragments,
    pub downstreams: FragmentDownstreamRelation,
}

impl Deref for StreamJobFragmentsToCreate {
    type Target = StreamJobFragments;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

#[derive(Clone, Debug)]
pub struct StreamActor {
    pub actor_id: ActorId,
    pub fragment_id: FragmentId,
    pub vnode_bitmap: Option<Bitmap>,
    pub mview_definition: String,
    pub expr_context: Option<PbExprContext>,
    /// The streaming config override of the actor, serialized as a string.
    pub config_override: Arc<str>,
}

impl StreamActor {
    fn to_protobuf(&self, dispatchers: impl Iterator<Item = Dispatcher>) -> PbStreamActor {
        PbStreamActor {
            actor_id: self.actor_id,
            fragment_id: self.fragment_id,
            dispatcher: dispatchers.collect(),
            vnode_bitmap: self
                .vnode_bitmap
                .as_ref()
                .map(|bitmap| bitmap.to_protobuf()),
            mview_definition: self.mview_definition.clone(),
            expr_context: self.expr_context.clone(),
            config_override: self.config_override.to_string(),
        }
    }
}

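/// Metadata of a fragment: the stream node tree shared by all actors of the
/// fragment, together with its distribution and state table information.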
#[derive(Clone, Debug, Default)]
pub struct Fragment {
    pub fragment_id: FragmentId,
    pub fragment_type_mask: FragmentTypeMask,
    pub distribution_type: PbFragmentDistributionType,
    pub state_table_ids: Vec<TableId>,
    pub maybe_vnode_count: Option<u32>,
    pub nodes: StreamNode,
}

impl Fragment {
    pub fn to_protobuf(
        &self,
        actors: &[StreamActor],
        upstream_fragments: impl Iterator<Item = FragmentId>,
        dispatchers: Option<&HashMap<ActorId, Vec<Dispatcher>>>,
    ) -> PbFragment {
        PbFragment {
            fragment_id: self.fragment_id,
            fragment_type_mask: self.fragment_type_mask.into(),
            distribution_type: self.distribution_type as _,
            actors: actors
                .iter()
                .map(|actor| {
                    actor.to_protobuf(
                        dispatchers
                            .and_then(|dispatchers| dispatchers.get(&actor.actor_id))
                            .into_iter()
                            .flatten()
                            .cloned(),
                    )
                })
                .collect(),
            state_table_ids: self.state_table_ids.clone(),
            upstream_fragment_ids: upstream_fragments.collect(),
            maybe_vnode_count: self.maybe_vnode_count,
            nodes: Some(self.nodes.clone()),
        }
    }
}

impl VnodeCountCompat for Fragment {
    fn vnode_count_inner(&self) -> VnodeCount {
        VnodeCount::from_protobuf(self.maybe_vnode_count, || self.is_singleton())
    }
}

impl IsSingleton for Fragment {
    fn is_singleton(&self) -> bool {
        matches!(self.distribution_type, FragmentDistributionType::Single)
    }
}

impl From<fragment::Model> for Fragment {
    fn from(model: fragment::Model) -> Self {
        Self {
            fragment_id: model.fragment_id,
            fragment_type_mask: FragmentTypeMask::from(model.fragment_type_mask),
            distribution_type: model.distribution_type.into(),
            state_table_ids: model.state_table_ids.into_inner(),
            maybe_vnode_count: VnodeCount::set(model.vnode_count).to_protobuf(),
            nodes: model.stream_node.to_protobuf(),
        }
    }
}

/// In-memory representation of the fragments of a single streaming job, the
/// counterpart of the `table_fragments` metadata.
#[derive(Debug, Clone)]
pub struct StreamJobFragments {
    /// The id of the streaming job.
    pub stream_job_id: JobId,

    /// The state of the streaming job.
    pub state: State,

    /// The fragments of the streaming job, keyed by fragment id.
    pub fragments: BTreeMap<FragmentId, Fragment>,

    /// The streaming context of the job, e.g. timezone and config override.
    pub ctx: StreamContext,

    /// The maximum parallelism the job can be scaled to, i.e. its expected
    /// vnode count.
    pub max_parallelism: usize,
}

/// The context attached to a streaming job at creation time and shared by all
/// of its actors.
#[derive(Debug, Clone, Default)]
pub struct StreamContext {
    /// The timezone used for time-related expressions, if set.
    pub timezone: Option<String>,

    /// The streaming config override of the job, serialized as a string.
    pub config_override: Arc<str>,
}

impl StreamContext {
    pub fn to_protobuf(&self) -> PbStreamContext {
        PbStreamContext {
            timezone: self.timezone.clone().unwrap_or("".into()),
            config_override: self.config_override.to_string(),
        }
    }

    pub fn to_expr_context(&self) -> PbExprContext {
        PbExprContext {
            // `timezone` should always be set in practice; an obviously invalid
            // placeholder is used so that a missing value is easy to spot.
            time_zone: self.timezone.clone().unwrap_or("Empty Time Zone".into()),
            strict_mode: false,
        }
    }

    pub fn from_protobuf(prost: &PbStreamContext) -> Self {
        Self {
            timezone: if prost.get_timezone().is_empty() {
                None
            } else {
                Some(prost.get_timezone().clone())
            },
            config_override: prost.get_config_override().as_str().into(),
        }
    }
}

#[easy_ext::ext(StreamingJobModelContextExt)]
impl risingwave_meta_model::streaming_job::Model {
    pub fn stream_context(&self) -> StreamContext {
        StreamContext {
            timezone: self.timezone.clone(),
            config_override: self.config_override.clone().unwrap_or_default().into(),
        }
    }
}

impl StreamJobFragments {
    pub fn to_protobuf(
        &self,
        fragment_actors: &HashMap<FragmentId, Vec<StreamActor>>,
        fragment_upstreams: &HashMap<FragmentId, HashSet<FragmentId>>,
        fragment_dispatchers: &FragmentActorDispatchers,
        actor_status: HashMap<ActorId, PbActorStatus>,
    ) -> PbTableFragments {
        PbTableFragments {
            table_id: self.stream_job_id,
            state: self.state as _,
            fragments: self
                .fragments
                .iter()
                .map(|(id, fragment)| {
                    let actors = fragment_actors.get(id).map(|a| a.as_slice()).unwrap_or(&[]);
                    (
                        *id,
                        fragment.to_protobuf(
                            actors,
                            fragment_upstreams.get(id).into_iter().flatten().cloned(),
                            fragment_dispatchers.get(id),
                        ),
                    )
                })
                .collect(),
            actor_status,
            ctx: Some(self.ctx.to_protobuf()),
            node_label: "".to_owned(),
            backfill_done: true,
            max_parallelism: Some(self.max_parallelism as _),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_context_protobuf_round_trip() {
        let context = StreamContext {
            timezone: Some("UTC".to_owned()),
            config_override: "{\"a\":1}".into(),
        };

        let prost = context.to_protobuf();

        let round_trip = StreamContext::from_protobuf(&prost);
        assert_eq!(round_trip.timezone, Some("UTC".to_owned()));
        assert_eq!(&*round_trip.config_override, "{\"a\":1}");
    }

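    // A round-trip sketch for the parallelism conversions defined above:
    // `TableParallelism` -> protobuf -> `TableParallelism` is lossless. (The
    // reverse direction is not injective, since `Auto` also maps to `Adaptive`.)
    #[test]
    fn test_table_parallelism_protobuf_round_trip() {
        for parallelism in [
            TableParallelism::Adaptive,
            TableParallelism::Fixed(4),
            TableParallelism::Custom,
        ] {
            let prost = PbTableParallelism::from(parallelism);
            assert_eq!(TableParallelism::from(prost), parallelism);
        }
    }
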
    #[test]
    fn test_stream_context_from_model() {
        let model = risingwave_meta_model::streaming_job::Model {
            job_id: 1.into(),
            job_status: risingwave_meta_model::JobStatus::Created,
            create_type: risingwave_meta_model::CreateType::Foreground,
            timezone: Some("Asia/Shanghai".to_owned()),
            config_override: Some("{\"parallelism\":2}".to_owned()),
            adaptive_parallelism_strategy: Some("AUTO".to_owned()),
            parallelism: StreamingParallelism::Adaptive,
            backfill_parallelism: Some(StreamingParallelism::Adaptive),
            backfill_adaptive_parallelism_strategy: Some("RATIO(0.25)".to_owned()),
            backfill_orders: None,
            max_parallelism: 32,
            specific_resource_group: None,
            is_serverless_backfill: false,
        };

        let context = model.stream_context();
        assert_eq!(context.timezone, Some("Asia/Shanghai".to_owned()));
        assert_eq!(&*context.config_override, "{\"parallelism\":2}");
    }
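
    // A minimal sketch of `StreamContext::to_expr_context`: with no timezone
    // set, the placeholder string is used and strict mode stays off.
    #[test]
    fn test_expr_context_placeholder_timezone() {
        let context = StreamContext::default();
        let expr_context = context.to_expr_context();
        assert_eq!(expr_context.time_zone, "Empty Time Zone");
        assert!(!expr_context.strict_mode);
    }

    // `StreamingParallelism` mirrors `TableParallelism`, so converting to it
    // and back is the identity.
    #[test]
    fn test_streaming_parallelism_round_trip() {
        for parallelism in [
            TableParallelism::Adaptive,
            TableParallelism::Fixed(8),
            TableParallelism::Custom,
        ] {
            let streaming = StreamingParallelism::from(parallelism);
            assert_eq!(TableParallelism::from(streaming), parallelism);
        }
    }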
}

/// Actors to be created, grouped by worker and then by fragment: for each
/// fragment, its stream node tree, the actors with their upstream and
/// downstream info, and the ids of its subscribers.
pub type StreamJobActorsToCreate = HashMap<
    WorkerId,
    HashMap<
        FragmentId,
        (
            StreamNode,
            Vec<StreamActorWithUpDownstreams>,
            HashSet<SubscriberId>,
        ),
    >,
>;

impl StreamJobFragments {
    /// Creates a `StreamJobFragments` for testing, with `max_parallelism` fixed
    /// to `VirtualNode::COUNT_FOR_TEST`.
    pub fn for_test(job_id: JobId, fragments: BTreeMap<FragmentId, Fragment>) -> Self {
        Self::new(
            job_id,
            fragments,
            StreamContext::default(),
            VirtualNode::COUNT_FOR_TEST,
        )
    }

    /// Creates a new `StreamJobFragments` in the `Initial` state.
    pub fn new(
        stream_job_id: JobId,
        fragments: BTreeMap<FragmentId, Fragment>,
        ctx: StreamContext,
        max_parallelism: usize,
    ) -> Self {
        Self {
            stream_job_id,
            state: State::Initial,
            fragments,
            ctx,
            max_parallelism,
        }
    }

    pub fn fragment_ids(&self) -> impl Iterator<Item = FragmentId> + '_ {
        self.fragments.keys().cloned()
    }

    pub fn fragments(&self) -> impl Iterator<Item = &Fragment> {
        self.fragments.values()
    }

    /// Returns the id of the streaming job.
    pub fn stream_job_id(&self) -> JobId {
        self.stream_job_id
    }

    /// Returns the timezone of the job, if set.
    pub fn timezone(&self) -> Option<String> {
        self.ctx.timezone.clone()
    }

    /// Returns whether the job has been fully created.
    pub fn is_created(&self) -> bool {
        self.state == State::Created
    }

    /// Returns the ids of all fragments containing an `Mview` node.
    #[cfg(test)]
    pub fn mview_fragment_ids(&self) -> Vec<FragmentId> {
        self.fragments
            .values()
            .filter(move |fragment| {
                fragment
                    .fragment_type_mask
                    .contains(FragmentTypeFlag::Mview)
            })
            .map(|fragment| fragment.fragment_id)
            .collect()
    }

    /// Shared implementation for collecting the actors whose backfill progress
    /// should be tracked, given each fragment's type mask and actor ids.
    pub fn tracking_progress_actor_ids_impl(
        fragments: impl IntoIterator<Item = (FragmentTypeMask, impl Iterator<Item = ActorId>)>,
    ) -> Vec<(ActorId, BackfillUpstreamType)> {
        let mut actor_ids = vec![];
        for (fragment_type_mask, actors) in fragments {
            if fragment_type_mask.contains(FragmentTypeFlag::CdcFilter) {
                // A CDC table job contains a `CdcFilter` fragment. Backfill
                // progress is not tracked for such jobs, so report no actors
                // for the whole job.
                return vec![];
            }
            if fragment_type_mask.contains_any([
                FragmentTypeFlag::Values,
                FragmentTypeFlag::StreamScan,
                FragmentTypeFlag::SourceScan,
                FragmentTypeFlag::LocalityProvider,
            ]) {
                actor_ids.extend(actors.map(|actor_id| {
                    (
                        actor_id,
                        BackfillUpstreamType::from_fragment_type_mask(fragment_type_mask),
                    )
                }));
            }
        }
        actor_ids
    }

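    /// Returns the "root" fragment of the job: the mview fragment if any,
    /// otherwise the sink fragment, otherwise the source fragment.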
    pub fn root_fragment(&self) -> Option<Fragment> {
        self.mview_fragment()
            .or_else(|| self.sink_fragment())
            .or_else(|| self.source_fragment())
    }

    /// Returns the fragment containing an `Mview` node, if any.
    pub fn mview_fragment(&self) -> Option<Fragment> {
        self.fragments
            .values()
            .find(|fragment| {
                fragment
                    .fragment_type_mask
                    .contains(FragmentTypeFlag::Mview)
            })
            .cloned()
    }

    /// Returns the fragment containing a `Source` node, if any.
    pub fn source_fragment(&self) -> Option<Fragment> {
        self.fragments
            .values()
            .find(|fragment| {
                fragment
                    .fragment_type_mask
                    .contains(FragmentTypeFlag::Source)
            })
            .cloned()
    }

    /// Returns the fragment containing a `Sink` node, if any.
    pub fn sink_fragment(&self) -> Option<Fragment> {
        self.fragments
            .values()
            .find(|fragment| fragment.fragment_type_mask.contains(FragmentTypeFlag::Sink))
            .cloned()
    }

    /// Returns the mapping from source id to the ids of the fragments containing
    /// a stream source node reading from it.
    pub fn stream_source_fragments(&self) -> HashMap<SourceId, BTreeSet<FragmentId>> {
        let mut source_fragments = HashMap::new();

        for fragment in self.fragments() {
            if let Some(source_id) = fragment.nodes.find_stream_source() {
                source_fragments
                    .entry(source_id)
                    .or_insert(BTreeSet::new())
                    .insert(fragment.fragment_id as FragmentId);
            }
        }
        source_fragments
    }

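    /// Returns the mapping from source id to the set of
    /// `(source_backfill_fragment_id, upstream_source_fragment_id)` pairs, one
    /// per fragment containing a `SourceBackfill` node for that source.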
    pub fn source_backfill_fragments(
        &self,
    ) -> HashMap<SourceId, BTreeSet<(FragmentId, FragmentId)>> {
        Self::source_backfill_fragments_impl(
            self.fragments
                .iter()
                .map(|(fragment_id, fragment)| (*fragment_id, &fragment.nodes)),
        )
    }

    /// Shared implementation of [`Self::source_backfill_fragments`], taking the
    /// fragment id and stream node tree of each fragment directly.
    pub fn source_backfill_fragments_impl(
        fragments: impl Iterator<Item = (FragmentId, &StreamNode)>,
    ) -> HashMap<SourceId, BTreeSet<(FragmentId, FragmentId)>> {
        let mut source_backfill_fragments = HashMap::new();

        for (fragment_id, fragment_node) in fragments {
            if let Some((source_id, upstream_source_fragment_id)) =
                fragment_node.find_source_backfill()
            {
                source_backfill_fragments
                    .entry(source_id)
                    .or_insert(BTreeSet::new())
                    .insert((fragment_id, upstream_source_fragment_id));
            }
        }
        source_backfill_fragments
    }

    /// Returns the fragment containing the `Union` node that merges changes
    /// into the table, panicking if no such fragment exists. A job may contain
    /// at most one such fragment.
    pub fn union_fragment_for_table(&mut self) -> &mut Fragment {
        let mut union_fragment_id = None;
        for (fragment_id, fragment) in &self.fragments {
            visit_stream_node_body(&fragment.nodes, |body| {
                if let NodeBody::Union(_) = body {
                    if let Some(union_fragment_id) = union_fragment_id.as_mut() {
                        // There can be at most one union fragment per job.
                        assert_eq!(*union_fragment_id, *fragment_id);
                    } else {
                        union_fragment_id = Some(*fragment_id);
                    }
                }
            });
        }

        let union_fragment_id =
            union_fragment_id.expect("fragment of placeholder merger not found");

        self.fragments
            .get_mut(&union_fragment_id)
            .unwrap_or_else(|| panic!("fragment {} not found", union_fragment_id))
    }

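    /// Recursively walks `stream_node`, counting for each dependent table id how
    /// many `StreamScan`, `StreamCdcScan`, or `LocalityProvider` nodes reference it.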
    fn resolve_dependent_table(stream_node: &StreamNode, table_ids: &mut HashMap<TableId, usize>) {
        let table_id = match stream_node.node_body.as_ref() {
            Some(NodeBody::StreamScan(stream_scan)) => Some(stream_scan.table_id),
            Some(NodeBody::StreamCdcScan(stream_scan)) => Some(stream_scan.table_id),
            Some(NodeBody::LocalityProvider(state)) => {
                Some(state.state_table.as_ref().expect("must have state").id)
            }
            _ => None,
        };
        if let Some(table_id) = table_id {
            table_ids.entry(table_id).or_default().add_assign(1);
        }

        for child in &stream_node.input {
            Self::resolve_dependent_table(child, table_ids);
        }
    }

    /// Returns, for each table this job depends on, the number of scan nodes in
    /// the job that reference it.
    pub fn upstream_table_counts(&self) -> HashMap<TableId, usize> {
        Self::upstream_table_counts_impl(self.fragments.values().map(|fragment| &fragment.nodes))
    }

    /// Shared implementation of [`Self::upstream_table_counts`], taking the
    /// stream node tree of each fragment directly.
    pub fn upstream_table_counts_impl(
        fragment_nodes: impl Iterator<Item = &StreamNode>,
    ) -> HashMap<TableId, usize> {
        let mut table_ids = HashMap::new();
        fragment_nodes.for_each(|node| {
            Self::resolve_dependent_table(node, &mut table_ids);
        });

        table_ids
    }

    /// Returns the id of the table materialized by this job, if any.
    pub fn mv_table_id(&self) -> Option<TableId> {
        self.fragments
            .values()
            .flat_map(|f| f.state_table_ids.iter().copied())
            .find(|table_id| self.stream_job_id.is_mv_table_id(*table_id))
    }

    /// Collects all state tables referenced by the given fragments' stream
    /// nodes, panicking on duplicated table ids.
    pub fn collect_tables(fragments: impl Iterator<Item = &Fragment>) -> BTreeMap<TableId, Table> {
        let mut tables = BTreeMap::new();
        for fragment in fragments {
            stream_graph_visitor::visit_stream_node_tables_inner(
                // The visitor requires a mutable reference, so visit a clone.
                &mut fragment.nodes.clone(),
                false,
                true,
                |table, _| {
                    let table_id = table.id;
                    tables
                        .try_insert(table_id, table.clone())
                        .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
                },
            );
        }
        tables
    }

    /// Returns the ids of all internal state tables, i.e. all state tables
    /// except the one materialized by the job itself.
    pub fn internal_table_ids(&self) -> Vec<TableId> {
        self.fragments
            .values()
            .flat_map(|f| f.state_table_ids.iter().copied())
            .filter(|&t| !self.stream_job_id.is_mv_table_id(t))
            .collect_vec()
    }

    /// Returns the ids of all state tables, including the mview table if any.
    pub fn all_table_ids(&self) -> impl Iterator<Item = TableId> + '_ {
        self.fragments
            .values()
            .flat_map(|f| f.state_table_ids.clone())
    }
}

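/// The type of upstream that a backfill fragment reads from, derived from the
/// fragment's type flags.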
#[derive(Debug, Display, Clone, Copy, PartialEq, Eq)]
pub enum BackfillUpstreamType {
    MView,
    Values,
    Source,
    LocalityProvider,
}

impl BackfillUpstreamType {
    pub fn from_fragment_type_mask(mask: FragmentTypeMask) -> Self {
        let is_mview = mask.contains(FragmentTypeFlag::StreamScan);
        let is_values = mask.contains(FragmentTypeFlag::Values);
        let is_source = mask.contains(FragmentTypeFlag::SourceScan);
        let is_locality_provider = mask.contains(FragmentTypeFlag::LocalityProvider);

        // A backfill fragment is expected to carry exactly one of these flags.
        debug_assert!(
            is_mview as u8 + is_values as u8 + is_source as u8 + is_locality_provider as u8 == 1,
            "a backfill fragment should either be mview, value, source, or locality provider, found {:?}",
            mask
        );

        if is_mview {
            BackfillUpstreamType::MView
        } else if is_values {
            BackfillUpstreamType::Values
        } else if is_source {
            BackfillUpstreamType::Source
        } else if is_locality_provider {
            BackfillUpstreamType::LocalityProvider
        } else {
            unreachable!("invalid fragment type mask: {:?}", mask);
        }
    }
}