1use std::collections::HashMap;
2
3use sea_orm::{ConnectionTrait, FromQueryResult, Statement};
4use sea_orm_migration::prelude::*;
5
// ---------------------------------------------------------------------------------------------
// Session parameter names written (upserted) by this migration.
// ---------------------------------------------------------------------------------------------

/// Global streaming parallelism session parameter.
const STREAMING_PARALLELISM: &str = "streaming_parallelism";
/// Per-type streaming parallelism for tables.
const STREAMING_PARALLELISM_FOR_TABLE: &str = "streaming_parallelism_for_table";
/// Per-type streaming parallelism for sources.
const STREAMING_PARALLELISM_FOR_SOURCE: &str = "streaming_parallelism_for_source";
/// Per-type streaming parallelism for sinks.
const STREAMING_PARALLELISM_FOR_SINK: &str = "streaming_parallelism_for_sink";
/// Per-type streaming parallelism for indexes.
const STREAMING_PARALLELISM_FOR_INDEX: &str = "streaming_parallelism_for_index";
/// Per-type streaming parallelism for materialized views.
const STREAMING_PARALLELISM_FOR_MATERIALIZED_VIEW: &str =
    "streaming_parallelism_for_materialized_view";

// ---------------------------------------------------------------------------------------------
// Legacy names: read to derive the new values, then deleted by this migration.
// ---------------------------------------------------------------------------------------------

/// Legacy *system* parameter holding the adaptive parallelism strategy.
const LEGACY_ADAPTIVE_PARALLELISM_STRATEGY: &str = "adaptive_parallelism_strategy";
/// Legacy global strategy session parameter.
const LEGACY_STREAMING_PARALLELISM_STRATEGY: &str = "streaming_parallelism_strategy";
/// Legacy per-type strategy session parameter for tables.
const LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_TABLE: &str =
    "streaming_parallelism_strategy_for_table";
/// Legacy per-type strategy session parameter for sources.
const LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SOURCE: &str =
    "streaming_parallelism_strategy_for_source";
/// Legacy per-type strategy session parameter for sinks.
const LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SINK: &str =
    "streaming_parallelism_strategy_for_sink";
/// Legacy per-type strategy session parameter for indexes.
const LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_INDEX: &str =
    "streaming_parallelism_strategy_for_index";
/// Legacy per-type strategy session parameter for materialized views.
const LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_MATERIALIZED_VIEW: &str =
    "streaming_parallelism_strategy_for_materialized_view";
37
38#[derive(DeriveMigrationName)]
39pub struct Migration;
40
41#[async_trait::async_trait]
42impl MigrationTrait for Migration {
43 async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
44 let legacy_system_strategy = load_legacy_system_strategy(manager)
45 .await?
46 .unwrap_or_else(default_legacy_system_strategy);
47 migrate_legacy_streaming_parallelism_session_params(manager, legacy_system_strategy)
48 .await?;
49 migrate_legacy_streaming_job_strategy(manager, legacy_system_strategy).await?;
50 delete_legacy_system_strategy(manager).await?;
51 Ok(())
52 }
53
54 async fn down(&self, _manager: &SchemaManager) -> Result<(), DbErr> {
55 Err(DbErr::Migration(
56 "cannot rollback legacy streaming parallelism session parameter migration".to_owned(),
57 ))?
58 }
59}
60
61async fn migrate_legacy_streaming_parallelism_session_params(
62 manager: &SchemaManager<'_>,
63 legacy_system_strategy: AdaptiveParallelismStrategy,
64) -> Result<(), DbErr> {
65 let conn = manager.get_connection();
66 let database_backend = conn.get_database_backend();
67
68 let (sql, values) = Query::select()
69 .columns([SessionParameter::Name, SessionParameter::Value])
70 .from(SessionParameter::Table)
71 .to_owned()
72 .build_any(&*database_backend.get_query_builder());
73 let rows = conn
74 .query_all(Statement::from_sql_and_values(
75 database_backend,
76 sql,
77 values,
78 ))
79 .await?;
80 let params = rows
81 .into_iter()
82 .map(|row| SessionParameterRow::from_query_result(&row, ""))
83 .collect::<Result<Vec<_>, _>>()?;
84
85 if !params
86 .iter()
87 .any(|param| is_migratable_streaming_parallelism_session_param(¶m.name))
88 {
89 if legacy_system_strategy == default_legacy_system_strategy() {
90 return Ok(());
91 }
92
93 manager
94 .exec_stmt(
95 Query::insert()
96 .into_table(SessionParameter::Table)
97 .columns([SessionParameter::Name, SessionParameter::Value])
98 .values_panic([
99 STREAMING_PARALLELISM.into(),
100 migrate_legacy_global_parallelism(
101 ConfigParallelism::Default,
102 ConfigAdaptiveParallelismStrategy::Default,
103 legacy_system_strategy,
104 )
105 .to_string()
106 .into(),
107 ])
108 .on_conflict(
109 sea_query::OnConflict::column(SessionParameter::Name)
110 .update_column(SessionParameter::Value)
111 .to_owned(),
112 )
113 .to_owned(),
114 )
115 .await?;
116 return Ok(());
117 }
118
119 let derived = derive_legacy_streaming_parallelism_params(¶ms, legacy_system_strategy);
120
121 if !derived.is_empty() {
122 let mut insert = Query::insert();
123 insert
124 .into_table(SessionParameter::Table)
125 .columns([SessionParameter::Name, SessionParameter::Value])
126 .on_conflict(
127 sea_query::OnConflict::column(SessionParameter::Name)
128 .update_column(SessionParameter::Value)
129 .to_owned(),
130 );
131
132 for (name, value) in derived {
133 insert.values_panic([name.into(), value.into()]);
134 }
135
136 manager.exec_stmt(insert.to_owned()).await?;
137 }
138
139 manager
140 .exec_stmt(
141 Query::delete()
142 .from_table(SessionParameter::Table)
143 .and_where(Expr::col(SessionParameter::Name).is_in([
144 LEGACY_STREAMING_PARALLELISM_STRATEGY,
145 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_TABLE,
146 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SOURCE,
147 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SINK,
148 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_INDEX,
149 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_MATERIALIZED_VIEW,
150 ]))
151 .to_owned(),
152 )
153 .await?;
154
155 Ok(())
156}
157
158async fn migrate_legacy_streaming_job_strategy(
159 manager: &SchemaManager<'_>,
160 legacy_system_strategy: AdaptiveParallelismStrategy,
161) -> Result<(), DbErr> {
162 manager
166 .exec_stmt(
167 Query::update()
168 .table(StreamingJob::Table)
169 .value(
170 StreamingJob::AdaptiveParallelismStrategy,
171 legacy_system_strategy.to_string(),
172 )
173 .and_where(Expr::col(StreamingJob::AdaptiveParallelismStrategy).is_null())
174 .to_owned(),
175 )
176 .await?;
177
178 Ok(())
179}
180
181async fn delete_legacy_system_strategy(manager: &SchemaManager<'_>) -> Result<(), DbErr> {
182 manager
183 .exec_stmt(
184 Query::delete()
185 .from_table(SystemParameter::Table)
186 .and_where(
187 Expr::col(SystemParameter::Name).eq(LEGACY_ADAPTIVE_PARALLELISM_STRATEGY),
188 )
189 .to_owned(),
190 )
191 .await?;
192 Ok(())
193}
194
195async fn load_legacy_system_strategy(
196 manager: &SchemaManager<'_>,
197) -> Result<Option<AdaptiveParallelismStrategy>, DbErr> {
198 let conn = manager.get_connection();
199 let database_backend = conn.get_database_backend();
200 let (sql, values) = Query::select()
201 .column(SystemParameter::Value)
202 .from(SystemParameter::Table)
203 .and_where(Expr::col(SystemParameter::Name).eq(LEGACY_ADAPTIVE_PARALLELISM_STRATEGY))
204 .to_owned()
205 .build_any(&*database_backend.get_query_builder());
206 let rows = conn
207 .query_all(Statement::from_sql_and_values(
208 database_backend,
209 sql,
210 values,
211 ))
212 .await?;
213
214 let Some(row) = rows.into_iter().next() else {
215 return Ok(None);
216 };
217 let row = SystemParameterRow::from_query_result(&row, "")?;
218
219 Ok(parse_adaptive_parallelism_strategy(&row.value))
220}
221
222fn default_legacy_system_strategy() -> AdaptiveParallelismStrategy {
223 AdaptiveParallelismStrategy::Bounded(64)
225}
226
227fn is_legacy_streaming_parallelism_strategy_param(name: &str) -> bool {
228 matches!(
229 name,
230 LEGACY_STREAMING_PARALLELISM_STRATEGY
231 | LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_TABLE
232 | LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SOURCE
233 | LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SINK
234 | LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_INDEX
235 | LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_MATERIALIZED_VIEW
236 )
237}
238
239fn is_streaming_parallelism_param(name: &str) -> bool {
240 matches!(
241 name,
242 STREAMING_PARALLELISM
243 | STREAMING_PARALLELISM_FOR_TABLE
244 | STREAMING_PARALLELISM_FOR_SOURCE
245 | STREAMING_PARALLELISM_FOR_SINK
246 | STREAMING_PARALLELISM_FOR_INDEX
247 | STREAMING_PARALLELISM_FOR_MATERIALIZED_VIEW
248 )
249}
250
251fn is_migratable_streaming_parallelism_session_param(name: &str) -> bool {
252 is_streaming_parallelism_param(name) || is_legacy_streaming_parallelism_strategy_param(name)
253}
254
255fn derive_legacy_streaming_parallelism_params(
256 params: &[SessionParameterRow],
257 legacy_system_strategy: AdaptiveParallelismStrategy,
258) -> HashMap<String, String> {
259 let param_map = params
260 .iter()
261 .map(|param| (param.name.as_str(), param.value.as_str()))
262 .collect::<HashMap<_, _>>();
263
264 let global_parallelism = parse_parallelism(
265 param_map.get(STREAMING_PARALLELISM).copied(),
266 ConfigParallelism::Default,
267 );
268 let global_strategy = parse_legacy_strategy(
269 param_map
270 .get(LEGACY_STREAMING_PARALLELISM_STRATEGY)
271 .copied(),
272 ConfigAdaptiveParallelismStrategy::Default,
273 );
274
275 let mut derived = HashMap::new();
276 if should_materialize_global_parallelism(
277 global_parallelism,
278 global_strategy,
279 legacy_system_strategy,
280 ) {
281 derived.insert(
282 STREAMING_PARALLELISM.to_owned(),
283 migrate_legacy_global_parallelism(
284 global_parallelism,
285 global_strategy,
286 legacy_system_strategy,
287 )
288 .to_string(),
289 );
290 }
291
292 for (parallelism_key, strategy_key) in [
293 (
294 STREAMING_PARALLELISM_FOR_TABLE,
295 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_TABLE,
296 ),
297 (
298 STREAMING_PARALLELISM_FOR_SOURCE,
299 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SOURCE,
300 ),
301 (
302 STREAMING_PARALLELISM_FOR_SINK,
303 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SINK,
304 ),
305 (
306 STREAMING_PARALLELISM_FOR_INDEX,
307 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_INDEX,
308 ),
309 (
310 STREAMING_PARALLELISM_FOR_MATERIALIZED_VIEW,
311 LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_MATERIALIZED_VIEW,
312 ),
313 ] {
314 let implicit_legacy_strategy = default_legacy_strategy_for_type(parallelism_key);
315 let specific_strategy_value = param_map.get(strategy_key).copied();
316 let specific_parallelism = parse_parallelism(
317 param_map.get(parallelism_key).copied(),
318 ConfigParallelism::Default,
319 );
320 let specific_strategy =
321 parse_legacy_strategy(specific_strategy_value, implicit_legacy_strategy);
322 let explicit_default_inherits_global = matches!(
323 specific_strategy_value,
324 Some(value) if value.eq_ignore_ascii_case("default")
325 ) && !matches!(
326 implicit_legacy_strategy,
327 ConfigAdaptiveParallelismStrategy::Default
328 );
329 derived.insert(
330 parallelism_key.to_owned(),
331 migrate_legacy_type_parallelism(
332 specific_parallelism,
333 specific_strategy,
334 explicit_default_inherits_global,
335 global_parallelism,
336 global_strategy,
337 legacy_system_strategy,
338 )
339 .to_string(),
340 );
341 }
342
343 derived
344}
345
346fn should_materialize_global_parallelism(
347 global_parallelism: ConfigParallelism,
348 global_strategy: ConfigAdaptiveParallelismStrategy,
349 legacy_system_strategy: AdaptiveParallelismStrategy,
350) -> bool {
351 !matches!(global_parallelism, ConfigParallelism::Default)
352 || !matches!(global_strategy, ConfigAdaptiveParallelismStrategy::Default)
353 || legacy_system_strategy != default_legacy_system_strategy()
354}
355
356fn default_legacy_strategy_for_type(_parallelism_key: &str) -> ConfigAdaptiveParallelismStrategy {
357 ConfigAdaptiveParallelismStrategy::Default
358}
359
360fn parse_parallelism(value: Option<&str>, default: ConfigParallelism) -> ConfigParallelism {
361 value.and_then(parse_config_parallelism).unwrap_or(default)
362}
363
364fn parse_legacy_strategy(
365 value: Option<&str>,
366 default: ConfigAdaptiveParallelismStrategy,
367) -> ConfigAdaptiveParallelismStrategy {
368 value
369 .and_then(parse_config_adaptive_parallelism_strategy)
370 .unwrap_or(default)
371}
372
373fn parse_config_parallelism(value: &str) -> Option<ConfigParallelism> {
374 if value.eq_ignore_ascii_case("default") {
375 return Some(ConfigParallelism::Default);
376 }
377 if value.eq_ignore_ascii_case("adaptive") || value.eq_ignore_ascii_case("auto") {
378 return Some(ConfigParallelism::Adaptive);
379 }
380 if let Some(strategy) = parse_adaptive_parallelism_strategy(value) {
381 return Some(match strategy {
382 AdaptiveParallelismStrategy::Auto | AdaptiveParallelismStrategy::Full => {
383 ConfigParallelism::Adaptive
384 }
385 AdaptiveParallelismStrategy::Bounded(n) => ConfigParallelism::Bounded(n as u64),
386 AdaptiveParallelismStrategy::Ratio(r) => ConfigParallelism::Ratio(r),
387 });
388 }
389
390 let parsed = value.parse::<u64>().ok()?;
391 Some(if parsed == 0 {
392 ConfigParallelism::Adaptive
393 } else {
394 ConfigParallelism::Fixed(parsed)
395 })
396}
397
398fn parse_config_adaptive_parallelism_strategy(
399 value: &str,
400) -> Option<ConfigAdaptiveParallelismStrategy> {
401 if value.eq_ignore_ascii_case("default") {
402 return Some(ConfigAdaptiveParallelismStrategy::Default);
403 }
404 Some(match parse_adaptive_parallelism_strategy(value)? {
405 AdaptiveParallelismStrategy::Auto => ConfigAdaptiveParallelismStrategy::Auto,
406 AdaptiveParallelismStrategy::Full => ConfigAdaptiveParallelismStrategy::Full,
407 AdaptiveParallelismStrategy::Bounded(n) => {
408 ConfigAdaptiveParallelismStrategy::Bounded(n as u64)
409 }
410 AdaptiveParallelismStrategy::Ratio(r) => ConfigAdaptiveParallelismStrategy::Ratio(r),
411 })
412}
413
414fn parse_adaptive_parallelism_strategy(value: &str) -> Option<AdaptiveParallelismStrategy> {
415 if value.eq_ignore_ascii_case("auto") {
416 return Some(AdaptiveParallelismStrategy::Auto);
417 }
418 if value.eq_ignore_ascii_case("full") {
419 return Some(AdaptiveParallelismStrategy::Full);
420 }
421
422 let lower = value.to_ascii_lowercase();
423 if let Some(inner) = lower
424 .strip_prefix("bounded(")
425 .and_then(|s| s.strip_suffix(')'))
426 {
427 let n = inner.parse::<usize>().ok()?;
428 return (n > 0).then_some(AdaptiveParallelismStrategy::Bounded(n));
429 }
430 if let Some(inner) = lower
431 .strip_prefix("ratio(")
432 .and_then(|s| s.strip_suffix(')'))
433 {
434 let r = inner.parse::<f32>().ok()?;
435 return ((0.0..=1.0).contains(&r)).then_some(AdaptiveParallelismStrategy::Ratio(r));
436 }
437
438 None
439}
440
441fn migrate_legacy_global_parallelism(
442 parallelism: ConfigParallelism,
443 strategy: ConfigAdaptiveParallelismStrategy,
444 system_strategy: AdaptiveParallelismStrategy,
445) -> ConfigParallelism {
446 match parallelism {
447 ConfigParallelism::Fixed(_)
448 | ConfigParallelism::Bounded(_)
449 | ConfigParallelism::Ratio(_) => parallelism,
450 ConfigParallelism::Default | ConfigParallelism::Adaptive => {
451 legacy_strategy_to_parallelism(resolve_legacy_strategy(strategy, system_strategy))
452 }
453 }
454}
455
456fn migrate_legacy_type_parallelism(
457 specific_parallelism: ConfigParallelism,
458 specific_strategy: ConfigAdaptiveParallelismStrategy,
459 explicit_default_inherits_global: bool,
460 global_parallelism: ConfigParallelism,
461 global_strategy: ConfigAdaptiveParallelismStrategy,
462 system_strategy: AdaptiveParallelismStrategy,
463) -> ConfigParallelism {
464 match specific_parallelism {
465 ConfigParallelism::Fixed(_)
466 | ConfigParallelism::Bounded(_)
467 | ConfigParallelism::Ratio(_) => specific_parallelism,
468 ConfigParallelism::Adaptive => legacy_strategy_to_parallelism(resolve_legacy_strategy(
469 specific_strategy,
470 resolve_legacy_strategy(global_strategy, system_strategy),
471 )),
472 ConfigParallelism::Default => {
473 if matches!(global_parallelism, ConfigParallelism::Fixed(_)) {
474 ConfigParallelism::Default
475 } else if explicit_default_inherits_global {
476 legacy_strategy_to_parallelism(resolve_legacy_strategy(
477 global_strategy,
478 system_strategy,
479 ))
480 } else if matches!(
481 specific_strategy,
482 ConfigAdaptiveParallelismStrategy::Default
483 ) {
484 ConfigParallelism::Default
485 } else {
486 legacy_strategy_to_parallelism(resolve_legacy_strategy(
487 specific_strategy,
488 resolve_legacy_strategy(global_strategy, system_strategy),
489 ))
490 }
491 }
492 }
493}
494
495fn resolve_legacy_strategy(
496 strategy: ConfigAdaptiveParallelismStrategy,
497 fallback: AdaptiveParallelismStrategy,
498) -> AdaptiveParallelismStrategy {
499 match strategy {
500 ConfigAdaptiveParallelismStrategy::Default => fallback,
501 ConfigAdaptiveParallelismStrategy::Auto => AdaptiveParallelismStrategy::Auto,
502 ConfigAdaptiveParallelismStrategy::Full => AdaptiveParallelismStrategy::Full,
503 ConfigAdaptiveParallelismStrategy::Bounded(n) => {
504 AdaptiveParallelismStrategy::Bounded(n as usize)
505 }
506 ConfigAdaptiveParallelismStrategy::Ratio(r) => AdaptiveParallelismStrategy::Ratio(r),
507 }
508}
509
510fn legacy_strategy_to_parallelism(strategy: AdaptiveParallelismStrategy) -> ConfigParallelism {
511 match strategy {
512 AdaptiveParallelismStrategy::Auto | AdaptiveParallelismStrategy::Full => {
513 ConfigParallelism::Adaptive
514 }
515 AdaptiveParallelismStrategy::Bounded(n) => ConfigParallelism::Bounded(n as u64),
516 AdaptiveParallelismStrategy::Ratio(r) => ConfigParallelism::Ratio(r),
517 }
518}
519
/// New-style parallelism value, written back as a `streaming_parallelism*` session
/// parameter through its `Display` form.
#[derive(Copy, Clone, Debug, PartialEq)]
enum ConfigParallelism {
    /// Rendered as `default`.
    Default,
    /// A fixed parallelism, rendered as the bare number.
    Fixed(u64),
    /// Rendered as `adaptive`.
    Adaptive,
    /// Rendered as `bounded(n)`.
    Bounded(u64),
    /// Rendered as `ratio(r)`.
    Ratio(f32),
}

impl std::fmt::Display for ConfigParallelism {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Default => f.write_str("default"),
            Self::Adaptive => f.write_str("adaptive"),
            Self::Fixed(count) => write!(f, "{count}"),
            Self::Bounded(limit) => write!(f, "bounded({limit})"),
            Self::Ratio(fraction) => write!(f, "ratio({fraction})"),
        }
    }
}
540
/// Legacy strategy as configured in a `streaming_parallelism_strategy*` session parameter.
/// Unlike `AdaptiveParallelismStrategy`, it can also be `Default`, meaning "use the
/// fallback" (see `resolve_legacy_strategy`).
#[derive(Copy, Clone, Debug, PartialEq)]
enum ConfigAdaptiveParallelismStrategy {
    /// No explicit strategy; resolved to a fallback.
    Default,
    /// `auto`
    Auto,
    /// `full`
    Full,
    /// `bounded(n)`
    Bounded(u64),
    /// `ratio(r)`
    Ratio(f32),
}
549
/// Concrete legacy adaptive parallelism strategy.
///
/// Parsed from the legacy `adaptive_parallelism_strategy` system parameter. Its upper-case
/// `Display` form is what gets written into `streaming_job.adaptive_parallelism_strategy`.
#[derive(Copy, Clone, Debug, PartialEq)]
enum AdaptiveParallelismStrategy {
    /// `AUTO`
    Auto,
    /// `FULL`
    Full,
    /// `BOUNDED(n)`
    Bounded(usize),
    /// `RATIO(r)`
    Ratio(f32),
}

impl std::fmt::Display for AdaptiveParallelismStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Auto => f.write_str("AUTO"),
            Self::Full => f.write_str("FULL"),
            Self::Bounded(limit) => write!(f, "BOUNDED({limit})"),
            Self::Ratio(fraction) => write!(f, "RATIO({fraction})"),
        }
    }
}
568
569#[derive(Debug, FromQueryResult)]
570struct SessionParameterRow {
571 name: String,
572 value: String,
573}
574
575#[derive(Debug, FromQueryResult)]
576struct SystemParameterRow {
577 value: String,
578}
579
580#[derive(DeriveIden)]
581enum SessionParameter {
582 Table,
583 Name,
584 Value,
585}
586
587#[derive(DeriveIden)]
588enum SystemParameter {
589 Table,
590 Name,
591 Value,
592}
593
594#[derive(DeriveIden)]
595enum StreamingJob {
596 Table,
597 AdaptiveParallelismStrategy,
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session parameter row for `derive_legacy_streaming_parallelism_params`.
    fn session_param(name: &str, value: &str) -> SessionParameterRow {
        SessionParameterRow {
            name: name.to_owned(),
            value: value.to_owned(),
        }
    }

    #[test]
    fn test_derive_legacy_streaming_parallelism_params_type_only_keeps_global_untouched() {
        let derived = derive_legacy_streaming_parallelism_params(
            &[session_param(
                LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SINK,
                "bounded(8)",
            )],
            AdaptiveParallelismStrategy::Bounded(64),
        );

        assert_eq!(derived.get(STREAMING_PARALLELISM), None);
        assert_eq!(
            derived.get(STREAMING_PARALLELISM_FOR_SINK),
            Some(&"bounded(8)".to_owned())
        );
        assert_eq!(
            derived.get(STREAMING_PARALLELISM_FOR_TABLE),
            Some(&"default".to_owned())
        );
    }

    #[test]
    fn test_derive_legacy_streaming_parallelism_params_materializes_custom_system_strategy() {
        let derived = derive_legacy_streaming_parallelism_params(
            &[session_param(
                LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SINK,
                "bounded(8)",
            )],
            AdaptiveParallelismStrategy::Bounded(16),
        );

        assert_eq!(
            derived.get(STREAMING_PARALLELISM),
            Some(&"bounded(16)".to_owned())
        );
        assert_eq!(
            derived.get(STREAMING_PARALLELISM_FOR_SINK),
            Some(&"bounded(8)".to_owned())
        );
    }

    /// A fixed global parallelism is preserved, and a per-type strategy does not override it
    /// for a type whose own parallelism is left at `default`.
    #[test]
    fn test_derive_legacy_streaming_parallelism_params_fixed_global_keeps_type_default() {
        let derived = derive_legacy_streaming_parallelism_params(
            &[
                session_param(STREAMING_PARALLELISM, "4"),
                session_param(LEGACY_STREAMING_PARALLELISM_STRATEGY_FOR_SINK, "bounded(8)"),
            ],
            AdaptiveParallelismStrategy::Bounded(64),
        );

        assert_eq!(derived.get(STREAMING_PARALLELISM), Some(&"4".to_owned()));
        assert_eq!(
            derived.get(STREAMING_PARALLELISM_FOR_SINK),
            Some(&"default".to_owned())
        );
    }

    /// An adaptive per-type parallelism without its own strategy inherits the global legacy
    /// strategy.
    #[test]
    fn test_derive_legacy_streaming_parallelism_params_adaptive_type_inherits_global_strategy() {
        let derived = derive_legacy_streaming_parallelism_params(
            &[
                session_param(STREAMING_PARALLELISM_FOR_TABLE, "adaptive"),
                session_param(LEGACY_STREAMING_PARALLELISM_STRATEGY, "ratio(0.5)"),
            ],
            AdaptiveParallelismStrategy::Bounded(64),
        );

        // Global entry plus one entry per type.
        assert_eq!(derived.len(), 6);
        assert_eq!(
            derived.get(STREAMING_PARALLELISM),
            Some(&"ratio(0.5)".to_owned())
        );
        assert_eq!(
            derived.get(STREAMING_PARALLELISM_FOR_TABLE),
            Some(&"ratio(0.5)".to_owned())
        );
        assert_eq!(
            derived.get(STREAMING_PARALLELISM_FOR_SINK),
            Some(&"default".to_owned())
        );
    }
}
651}