Skip to main content

risingwave_meta/manager/sink_coordination/
manager.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::pin::pin;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use futures::future::{BoxFuture, Either, join_all, select};
21use futures::stream::FuturesUnordered;
22use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
23use risingwave_common::bitmap::Bitmap;
24use risingwave_connector::connector_common::IcebergSinkCompactionUpdate;
25use risingwave_connector::sink::catalog::SinkId;
26use risingwave_connector::sink::{SinkCommittedEpochSubscriber, SinkError, SinkParam};
27use risingwave_pb::connector_service::coordinate_request::Msg;
28use risingwave_pb::connector_service::{CoordinateRequest, CoordinateResponse, coordinate_request};
29use rw_futures_util::pending_on_none;
30use sea_orm::DatabaseConnection;
31use thiserror_ext::AsReport;
32use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
33use tokio::sync::oneshot::{Receiver, Sender, channel};
34use tokio::sync::{mpsc, oneshot};
35use tokio::task::{JoinError, JoinHandle};
36use tokio_stream::wrappers::UnboundedReceiverStream;
37use tonic::Status;
38use tracing::{debug, error, info, warn};
39
40use crate::hummock::HummockManagerRef;
41use crate::manager::MetadataManager;
42use crate::manager::sink_coordination::SinkWriterRequestStream;
43use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
44use crate::manager::sink_coordination::handle::SinkWriterCoordinationHandle;
45
/// Sends `$msg` through `$tx`, whose `send` is not async (e.g. a oneshot or an unbounded
/// mpsc sender). A failed send is logged as an error instead of being propagated.
macro_rules! send_with_err_check {
    ($tx:expr, $msg:expr) => {
        if $tx.send($msg).is_err() {
            error!("unable to send msg");
        }
    };
}

/// Async counterpart of `send_with_err_check!`: awaits `$tx.send($msg)` (e.g. on a bounded
/// mpsc sender). A failed send is logged as an error instead of being propagated.
macro_rules! send_await_with_err_check {
    ($tx:expr, $msg:expr) => {
        if $tx.send($msg).await.is_err() {
            error!("unable to send msg");
        }
    };
}
61
/// Capacity of the bounded channel that carries [`ManagerRequest`]s from
/// [`SinkCoordinatorManager`] handles to the manager worker.
const BOUNDED_CHANNEL_SIZE: usize = 16;

/// Requests sent from [`SinkCoordinatorManager`] to its background manager worker.
enum ManagerRequest {
    /// A sink writer sent a valid start request and must be routed to the coordinator
    /// worker of its sink.
    NewSinkWriter(SinkWriterCoordinationHandle),
    /// Stop sink coordinators, then fire `finish_notifier`.
    StopCoordinator {
        /// Fired once the requested coordinators have finished.
        finish_notifier: Sender<()>,
        /// Sink ids to stop. When `None`, stop all sink coordinators.
        sink_ids: Option<Vec<SinkId>>,
    },
}
72
/// Handle to the sink coordination manager worker.
///
/// All clones share the same request channel to the single background manager worker
/// started by [`SinkCoordinatorManager::start_worker`].
#[derive(Clone)]
pub struct SinkCoordinatorManager {
    request_tx: mpsc::Sender<ManagerRequest>,
}
77fn new_committed_epoch_subscriber(
78    hummock_manager: HummockManagerRef,
79    metadata_manager: MetadataManager,
80) -> SinkCommittedEpochSubscriber {
81    Arc::new(move |sink_id| {
82        let hummock_manager = hummock_manager.clone();
83        let metadata_manager = metadata_manager.clone();
84        async move {
85            let state_table_ids = metadata_manager
86                .get_sink_state_table_ids(sink_id)
87                .await
88                .map_err(SinkError::from)?;
89            let Some(table_id) = state_table_ids.first() else {
90                return Err(anyhow!("no state table id in sink: {}", sink_id).into());
91            };
92            hummock_manager
93                .subscribe_table_committed_epoch(*table_id)
94                .await
95                .map_err(SinkError::from)
96        }
97        .boxed()
98    })
99}
100
101impl SinkCoordinatorManager {
102    pub fn start_worker(
103        db: DatabaseConnection,
104        hummock_manager: HummockManagerRef,
105        metadata_manager: MetadataManager,
106        iceberg_compact_stat_sender: UnboundedSender<IcebergSinkCompactionUpdate>,
107        await_tree_reg: await_tree::Registry,
108    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
109        let subscriber = new_committed_epoch_subscriber(hummock_manager, metadata_manager);
110        Self::start_worker_with_spawn_worker({
111            move |param, manager_request_stream| {
112                let sink_id = param.sink_id;
113                let fut = CoordinatorWorker::run(
114                    param,
115                    manager_request_stream,
116                    db.clone(),
117                    subscriber.clone(),
118                    iceberg_compact_stat_sender.clone(),
119                );
120                tokio::spawn(
121                    await_tree_reg
122                        .register_derived_root(format!("Sink Coordinator {sink_id}"))
123                        .instrument(fut),
124                )
125            }
126        })
127    }
128
129    fn start_worker_with_spawn_worker(
130        spawn_coordinator_worker: impl SpawnCoordinatorFn,
131    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
132        let (request_tx, request_rx) = mpsc::channel(BOUNDED_CHANNEL_SIZE);
133        let (shutdown_tx, shutdown_rx) = channel();
134        let worker = ManagerWorker::new(request_rx, shutdown_rx);
135        let join_handle = tokio::spawn(worker.execute(spawn_coordinator_worker));
136        (
137            SinkCoordinatorManager { request_tx },
138            (join_handle, shutdown_tx),
139        )
140    }
141
142    pub async fn handle_new_request(
143        &self,
144        mut request_stream: SinkWriterRequestStream,
145    ) -> Result<impl Stream<Item = Result<CoordinateResponse, Status>> + use<>, Status> {
146        let (param, vnode_bitmap) = match request_stream.try_next().await? {
147            Some(CoordinateRequest {
148                msg:
149                    Some(Msg::StartRequest(coordinate_request::StartCoordinationRequest {
150                        param: Some(param),
151                        vnode_bitmap: Some(vnode_bitmap),
152                    })),
153            }) => (SinkParam::from_proto(param), Bitmap::from(&vnode_bitmap)),
154            msg => {
155                return Err(Status::invalid_argument(format!(
156                    "expected CoordinateRequest::StartRequest in the first request, get {:?}",
157                    msg
158                )));
159            }
160        };
161        let (response_tx, response_rx) = mpsc::unbounded_channel();
162        self.request_tx
163            .send(ManagerRequest::NewSinkWriter(
164                SinkWriterCoordinationHandle::new(request_stream, response_tx, param, vnode_bitmap),
165            ))
166            .await
167            .map_err(|_| {
168                Status::unavailable(
169                    "unable to send to sink manager worker. The worker may have stopped",
170                )
171            })?;
172
173        Ok(UnboundedReceiverStream::new(response_rx))
174    }
175
176    async fn stop_coordinator(&self, sink_ids: Option<Vec<SinkId>>) {
177        let (tx, rx) = channel();
178        send_await_with_err_check!(
179            self.request_tx,
180            ManagerRequest::StopCoordinator {
181                finish_notifier: tx,
182                sink_ids: sink_ids.clone(),
183            }
184        );
185        if rx.await.is_err() {
186            error!("fail to wait for resetting sink manager worker");
187        }
188        info!("successfully stop coordinator: {:?}", sink_ids);
189    }
190
191    pub async fn reset(&self) {
192        self.stop_coordinator(None).await;
193    }
194
195    pub async fn stop_sink_coordinator(&self, sink_ids: Vec<SinkId>) {
196        self.stop_coordinator(Some(sink_ids)).await;
197    }
198}
199
/// Manager-side bookkeeping for one spawned coordinator worker.
struct CoordinatorWorkerHandle {
    /// Sender to coordinator worker. Drop the sender as a stop signal; it is `None` once the
    /// worker has been asked to stop.
    request_sender: Option<UnboundedSender<SinkWriterCoordinationHandle>>,
    /// Notify when the coordinator worker stops
    finish_notifiers: Vec<Sender<()>>,
}
206
/// Background task that owns every coordinator worker and serves [`ManagerRequest`]s.
struct ManagerWorker {
    /// Requests from [`SinkCoordinatorManager`] handles.
    request_rx: mpsc::Receiver<ManagerRequest>,
    /// Completes when shutdown is requested or its sender is dropped; the worker then exits.
    shutdown_rx: Receiver<()>,

    /// Join handles of spawned coordinator workers, each tagged with its sink id.
    running_coordinator_worker_join_handles:
        FuturesUnordered<BoxFuture<'static, (SinkId, Result<(), JoinError>)>>,
    /// Coordinator workers that have been spawned and whose completion has not been
    /// processed yet, keyed by sink id.
    running_coordinator_worker: HashMap<SinkId, CoordinatorWorkerHandle>,
}
216
/// Events that drive the manager worker loop.
enum ManagerEvent {
    /// A request received from a [`SinkCoordinatorManager`] handle.
    NewRequest(ManagerRequest),
    /// The coordinator worker task of `sink_id` has completed.
    CoordinatorWorkerFinished {
        sink_id: SinkId,
        /// Join result of the coordinator worker task; `Err` if the task did not complete
        /// normally (e.g. it panicked or was cancelled).
        join_result: Result<(), JoinError>,
    },
}
224
/// Function that spawns the coordinator worker of one sink.
///
/// It receives the sink's [`SinkParam`] and the receiver through which new sink writers of
/// that sink are delivered, and returns the spawned task's join handle. This is declared
/// with trait-alias syntax (the unstable `trait_alias` feature).
trait SpawnCoordinatorFn = FnMut(SinkParam, UnboundedReceiver<SinkWriterCoordinationHandle>) -> JoinHandle<()>
    + Send
    + 'static;
228
229impl ManagerWorker {
230    fn new(request_rx: mpsc::Receiver<ManagerRequest>, shutdown_rx: Receiver<()>) -> Self {
231        ManagerWorker {
232            request_rx,
233            shutdown_rx,
234            running_coordinator_worker_join_handles: Default::default(),
235            running_coordinator_worker: Default::default(),
236        }
237    }
238
239    async fn execute(mut self, mut spawn_coordinator_worker: impl SpawnCoordinatorFn) {
240        while let Some(event) = self.next_event().await {
241            match event {
242                ManagerEvent::NewRequest(request) => match request {
243                    ManagerRequest::NewSinkWriter(request) => {
244                        self.handle_new_sink_writer(request, &mut spawn_coordinator_worker)
245                    }
246                    ManagerRequest::StopCoordinator {
247                        finish_notifier,
248                        sink_ids,
249                    } => {
250                        if let Some(sink_ids) = sink_ids {
251                            let mut rxs = Vec::with_capacity(sink_ids.len());
252                            for sink_id in sink_ids {
253                                if let Some(worker_handle) =
254                                    self.running_coordinator_worker.get_mut(&sink_id)
255                                {
256                                    let (tx, rx) = oneshot::channel();
257                                    rxs.push(rx);
258                                    worker_handle.finish_notifiers.push(tx);
259                                    if let Some(sender) = worker_handle.request_sender.take() {
260                                        // drop the sender as a signal to notify the coordinator worker
261                                        // to stop
262                                        drop(sender);
263                                    }
264                                } else {
265                                    debug!(
266                                        "sink coordinator of {} is not running, skip it",
267                                        sink_id
268                                    );
269                                }
270                            }
271                            tokio::spawn(async move {
272                                let notify_res = join_all(rxs).await;
273                                for res in notify_res {
274                                    if let Err(e) = res {
275                                        error!(
276                                            "fail to wait for resetting sink manager worker: {}",
277                                            e.as_report()
278                                        );
279                                    }
280                                }
281                                send_with_err_check!(finish_notifier, ());
282                            });
283                        } else {
284                            self.clean_up().await;
285                            send_with_err_check!(finish_notifier, ());
286                        }
287                    }
288                },
289                ManagerEvent::CoordinatorWorkerFinished {
290                    sink_id,
291                    join_result,
292                } => self.handle_coordinator_finished(sink_id, join_result),
293            }
294        }
295        self.clean_up().await;
296        info!("sink manager worker exited");
297    }
298
299    async fn next_event(&mut self) -> Option<ManagerEvent> {
300        match select(
301            select(
302                pin!(self.request_rx.recv()),
303                pin!(pending_on_none(
304                    self.running_coordinator_worker_join_handles.next()
305                )),
306            ),
307            &mut self.shutdown_rx,
308        )
309        .await
310        {
311            Either::Left((either, _)) => match either {
312                Either::Left((Some(request), _)) => Some(ManagerEvent::NewRequest(request)),
313                Either::Left((None, _)) => None,
314                Either::Right(((sink_id, join_result), _)) => {
315                    Some(ManagerEvent::CoordinatorWorkerFinished {
316                        sink_id,
317                        join_result,
318                    })
319                }
320            },
321            Either::Right(_) => None,
322        }
323    }
324
325    async fn clean_up(&mut self) {
326        info!("sink manager worker start cleaning up");
327        for worker_handle in self.running_coordinator_worker.values_mut() {
328            if let Some(sender) = worker_handle.request_sender.take() {
329                // drop the sender to notify the coordinator worker to stop
330                drop(sender);
331            }
332        }
333        while let Some((sink_id, join_result)) =
334            self.running_coordinator_worker_join_handles.next().await
335        {
336            self.handle_coordinator_finished(sink_id, join_result);
337        }
338        info!("sink manager worker finished cleaning up");
339    }
340
341    fn handle_coordinator_finished(&mut self, sink_id: SinkId, join_result: Result<(), JoinError>) {
342        let worker_handle = self
343            .running_coordinator_worker
344            .remove(&sink_id)
345            .expect("finished coordinator should have an associated worker handle");
346        for finish_notifier in worker_handle.finish_notifiers {
347            send_with_err_check!(finish_notifier, ());
348        }
349        match join_result {
350            Ok(()) => {
351                info!(
352                    id = %sink_id,
353                    "sink coordinator has gracefully finished",
354                );
355            }
356            Err(err) => {
357                error!(
358                    id = %sink_id,
359                    error = %err.as_report(),
360                    "sink coordinator finished with error",
361                );
362            }
363        }
364    }
365
366    fn handle_new_sink_writer(
367        &mut self,
368        new_writer: SinkWriterCoordinationHandle,
369        spawn_coordinator_worker: &mut impl SpawnCoordinatorFn,
370    ) {
371        let param = new_writer.param();
372        let sink_id = param.sink_id;
373
374        let handle = self
375            .running_coordinator_worker
376            .entry(param.sink_id)
377            .or_insert_with(|| {
378                // Launch the coordinator worker task if it is the first
379                let (request_tx, request_rx) = unbounded_channel();
380                let join_handle = spawn_coordinator_worker(param.clone(), request_rx);
381                self.running_coordinator_worker_join_handles.push(
382                    join_handle
383                        .map(move |join_result| (sink_id, join_result))
384                        .boxed(),
385                );
386                CoordinatorWorkerHandle {
387                    request_sender: Some(request_tx),
388                    finish_notifiers: Vec::new(),
389                }
390            });
391
392        if let Some(sender) = handle.request_sender.as_mut() {
393            send_with_err_check!(sender, new_writer);
394        } else {
395            warn!(
396                "handle a new request while the sink coordinator is being stopped: {:?}",
397                param
398            );
399            new_writer.abort(Status::internal("the sink is being stopped"));
400        }
401    }
402}
403
404#[cfg(test)]
405mod tests {
406    use std::future::{Future, poll_fn};
407    use std::pin::pin;
408    use std::sync::Arc;
409    use std::sync::atomic::AtomicI32;
410    use std::task::Poll;
411
412    use anyhow::anyhow;
413    use async_trait::async_trait;
414    use futures::future::{join, try_join};
415    use futures::{FutureExt, StreamExt, TryFutureExt};
416    use itertools::Itertools;
417    use rand::seq::SliceRandom;
418    use risingwave_common::bitmap::BitmapBuilder;
419    use risingwave_common::hash::VirtualNode;
420    use risingwave_connector::sink::catalog::{SinkId, SinkType};
421    use risingwave_connector::sink::{
422        SinglePhaseCommitCoordinator, SinkCommitCoordinator, SinkError, SinkParam,
423        TwoPhaseCommitCoordinator,
424    };
425    use risingwave_meta_model::SinkSchemachange;
426    use risingwave_pb::connector_service::SinkMetadata;
427    use risingwave_pb::connector_service::sink_metadata::{Metadata, SerializedMetadata};
428    use risingwave_pb::data::PbDataType;
429    use risingwave_pb::data::data_type::PbTypeName;
430    use risingwave_pb::plan_common::PbField;
431    use risingwave_pb::stream_plan::sink_schema_change::Op as SinkSchemachangeOp;
432    use risingwave_pb::stream_plan::{PbSinkAddColumnsOp, PbSinkSchemaChange};
433    use risingwave_rpc_client::CoordinatorStreamHandle;
434    use sea_orm::{ConnectionTrait, Database, DatabaseConnection};
435    use tokio::sync::mpsc::unbounded_channel;
436    use tokio_stream::wrappers::ReceiverStream;
437
438    use crate::manager::sink_coordination::SinkCoordinatorManager;
439    use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
440    use crate::manager::sink_coordination::manager::SinkCommittedEpochSubscriber;
441
442    struct MockSinglePhaseCoordinator<
443        C,
444        F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError>,
445    > {
446        context: C,
447        f: F,
448    }
449
450    impl<
451        C: Send + 'static,
452        F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError> + Send + 'static,
453    > MockSinglePhaseCoordinator<C, F>
454    {
455        fn new_coordinator(context: C, f: F) -> SinkCommitCoordinator {
456            SinkCommitCoordinator::SinglePhase(Box::new(MockSinglePhaseCoordinator { context, f }))
457        }
458    }
459
460    #[async_trait]
461    impl<C: Send, F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError> + Send>
462        SinglePhaseCommitCoordinator for MockSinglePhaseCoordinator<C, F>
463    {
464        async fn init(&mut self) -> risingwave_connector::sink::Result<()> {
465            Ok(())
466        }
467
468        async fn commit_data(
469            &mut self,
470            epoch: u64,
471            metadata: Vec<SinkMetadata>,
472        ) -> risingwave_connector::sink::Result<()> {
473            (self.f)(epoch, metadata, &mut self.context)
474        }
475
476        async fn commit_schema_change(
477            &mut self,
478            _epoch: u64,
479            _schema_change: PbSinkSchemaChange,
480        ) -> risingwave_connector::sink::Result<()> {
481            unreachable!()
482        }
483    }
484
485    #[tokio::test]
486    async fn test_basic() {
487        let db = prepare_db_backend().await;
488
489        let param = SinkParam {
490            sink_id: SinkId::from(1),
491            sink_name: "test".into(),
492            properties: Default::default(),
493            columns: vec![],
494            downstream_pk: None,
495            sink_type: SinkType::AppendOnly,
496            ignore_delete: false,
497            format_desc: None,
498            db_name: "test".into(),
499            sink_from_name: "test".into(),
500        };
501
502        let epoch0 = 232;
503        let epoch1 = 233;
504        let epoch2 = 234;
505
506        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
507        all_vnode.shuffle(&mut rand::rng());
508        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
509        let build_bitmap = |indexes: &[usize]| {
510            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
511            for i in indexes {
512                builder.set(*i, true);
513            }
514            builder.finish()
515        };
516        let vnode1 = build_bitmap(first);
517        let vnode2 = build_bitmap(second);
518
519        let metadata = [
520            [vec![1u8, 2u8], vec![3u8, 4u8]],
521            [vec![5u8, 6u8], vec![7u8, 8u8]],
522        ];
523        let sender = Arc::new(tokio::sync::Mutex::new(None));
524        let mock_subscriber: SinkCommittedEpochSubscriber = {
525            let captured_sender = sender.clone();
526            Arc::new(move |_sink_id: SinkId| {
527                let (sender, receiver) = unbounded_channel();
528                let captured_sender = captured_sender.clone();
529                async move {
530                    let mut guard = captured_sender.lock().await;
531                    *guard = Some(sender);
532                    Ok((1, receiver))
533                }
534                .boxed()
535            })
536        };
537
538        let (manager, (_join_handle, _stop_tx)) =
539            SinkCoordinatorManager::start_worker_with_spawn_worker({
540                let expected_param = param.clone();
541                let metadata = metadata.clone();
542                let db = db.clone();
543                move |param, new_writer_rx| {
544                    let metadata = metadata.clone();
545                    let expected_param = expected_param.clone();
546                    let db = db.clone();
547                    tokio::spawn({
548                        let subscriber = mock_subscriber.clone();
549                        async move {
550                            // validate the start request
551                            assert_eq!(param, expected_param);
552                            CoordinatorWorker::execute_coordinator(
553                                db,
554                                param.clone(),
555                                new_writer_rx,
556                                MockSinglePhaseCoordinator::new_coordinator(
557                                    0,
558                                    move |epoch, metadata_list, count: &mut usize| {
559                                        *count += 1;
560                                        let mut metadata_list =
561                                            metadata_list
562                                                .into_iter()
563                                                .map(|metadata| match metadata {
564                                                    SinkMetadata {
565                                                        metadata:
566                                                            Some(Metadata::Serialized(
567                                                                SerializedMetadata { metadata },
568                                                            )),
569                                                    } => metadata,
570                                                    _ => unreachable!(),
571                                                })
572                                                .collect_vec();
573                                        metadata_list.sort();
574                                        match *count {
575                                            1 => {
576                                                assert_eq!(epoch, epoch1);
577                                                assert_eq!(2, metadata_list.len());
578                                                assert_eq!(metadata[0][0], metadata_list[0]);
579                                                assert_eq!(metadata[0][1], metadata_list[1]);
580                                            }
581                                            2 => {
582                                                assert_eq!(epoch, epoch2);
583                                                assert_eq!(2, metadata_list.len());
584                                                assert_eq!(metadata[1][0], metadata_list[0]);
585                                                assert_eq!(metadata[1][1], metadata_list[1]);
586                                            }
587                                            _ => unreachable!(),
588                                        }
589                                        Ok(())
590                                    },
591                                ),
592                                subscriber.clone(),
593                            )
594                            .await;
595                        }
596                    })
597                }
598            });
599
600        let build_client = |vnode| async {
601            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
602                Ok(tonic::Response::new(
603                    manager
604                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
605                        .await
606                        .unwrap()
607                        .boxed(),
608                ))
609            })
610            .await
611            .unwrap()
612            .0
613        };
614
615        let (mut client1, mut client2) =
616            join(build_client(vnode1), pin!(build_client(vnode2))).await;
617
618        let (aligned_epoch1, aligned_epoch2) = try_join(
619            client1.align_initial_epoch(epoch0),
620            client2.align_initial_epoch(epoch1),
621        )
622        .await
623        .unwrap();
624        assert_eq!(aligned_epoch1, epoch1);
625        assert_eq!(aligned_epoch2, epoch1);
626
627        {
628            // commit epoch1
629            let mut commit_future = pin!(
630                client2
631                    .commit(
632                        epoch1,
633                        SinkMetadata {
634                            metadata: Some(Metadata::Serialized(SerializedMetadata {
635                                metadata: metadata[0][1].clone(),
636                            })),
637                        },
638                        None,
639                    )
640                    .map(|result| result.unwrap())
641            );
642            assert!(
643                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
644                    .await
645                    .is_pending()
646            );
647            join(
648                commit_future,
649                client1
650                    .commit(
651                        epoch1,
652                        SinkMetadata {
653                            metadata: Some(Metadata::Serialized(SerializedMetadata {
654                                metadata: metadata[0][0].clone(),
655                            })),
656                        },
657                        None,
658                    )
659                    .map(|result| result.unwrap()),
660            )
661            .await;
662        }
663
664        // commit epoch2
665        let mut commit_future = pin!(
666            client1
667                .commit(
668                    epoch2,
669                    SinkMetadata {
670                        metadata: Some(Metadata::Serialized(SerializedMetadata {
671                            metadata: metadata[1][0].clone(),
672                        })),
673                    },
674                    None,
675                )
676                .map(|result| result.unwrap())
677        );
678        assert!(
679            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
680                .await
681                .is_pending()
682        );
683        join(
684            commit_future,
685            client2
686                .commit(
687                    epoch2,
688                    SinkMetadata {
689                        metadata: Some(Metadata::Serialized(SerializedMetadata {
690                            metadata: metadata[1][1].clone(),
691                        })),
692                    },
693                    None,
694                )
695                .map(|result| result.unwrap()),
696        )
697        .await;
698    }
699
700    #[tokio::test]
701    async fn test_single_writer() {
702        let db = prepare_db_backend().await;
703        let param = SinkParam {
704            sink_id: SinkId::from(1),
705            sink_name: "test".into(),
706            properties: Default::default(),
707            columns: vec![],
708            downstream_pk: None,
709            sink_type: SinkType::AppendOnly,
710            ignore_delete: false,
711            format_desc: None,
712            db_name: "test".into(),
713            sink_from_name: "test".into(),
714        };
715
716        let epoch1 = 233;
717        let epoch2 = 234;
718
719        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
720        let build_bitmap = |indexes: &[usize]| {
721            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
722            for i in indexes {
723                builder.set(*i, true);
724            }
725            builder.finish()
726        };
727        let vnode = build_bitmap(&all_vnode);
728
729        let metadata = [vec![1u8, 2u8], vec![3u8, 4u8]];
730        let sender = Arc::new(tokio::sync::Mutex::new(None));
731        let mock_subscriber: SinkCommittedEpochSubscriber = {
732            let captured_sender = sender.clone();
733            Arc::new(move |_sink_id: SinkId| {
734                let (sender, receiver) = unbounded_channel();
735                let captured_sender = captured_sender.clone();
736                async move {
737                    let mut guard = captured_sender.lock().await;
738                    *guard = Some(sender);
739                    Ok((1, receiver))
740                }
741                .boxed()
742            })
743        };
744        let (manager, (_join_handle, _stop_tx)) =
745            SinkCoordinatorManager::start_worker_with_spawn_worker({
746                let expected_param = param.clone();
747                let metadata = metadata.clone();
748                let db = db.clone();
749                move |param, new_writer_rx| {
750                    let metadata = metadata.clone();
751                    let expected_param = expected_param.clone();
752                    let db = db.clone();
753                    tokio::spawn({
754                        let subscriber = mock_subscriber.clone();
755                        async move {
756                            // validate the start request
757                            assert_eq!(param, expected_param);
758                            CoordinatorWorker::execute_coordinator(
759                                db,
760                                param.clone(),
761                                new_writer_rx,
762                                MockSinglePhaseCoordinator::new_coordinator(
763                                    0,
764                                    move |epoch, metadata_list, count: &mut usize| {
765                                        *count += 1;
766                                        let mut metadata_list =
767                                            metadata_list
768                                                .into_iter()
769                                                .map(|metadata| match metadata {
770                                                    SinkMetadata {
771                                                        metadata:
772                                                            Some(Metadata::Serialized(
773                                                                SerializedMetadata { metadata },
774                                                            )),
775                                                    } => metadata,
776                                                    _ => unreachable!(),
777                                                })
778                                                .collect_vec();
779                                        metadata_list.sort();
780                                        match *count {
781                                            1 => {
782                                                assert_eq!(epoch, epoch1);
783                                                assert_eq!(1, metadata_list.len());
784                                                assert_eq!(metadata[0], metadata_list[0]);
785                                            }
786                                            2 => {
787                                                assert_eq!(epoch, epoch2);
788                                                assert_eq!(1, metadata_list.len());
789                                                assert_eq!(metadata[1], metadata_list[0]);
790                                            }
791                                            _ => unreachable!(),
792                                        }
793                                        Ok(())
794                                    },
795                                ),
796                                subscriber.clone(),
797                            )
798                            .await;
799                        }
800                    })
801                }
802            });
803
804        let build_client = |vnode| async {
805            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
806                Ok(tonic::Response::new(
807                    manager
808                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
809                        .await
810                        .unwrap()
811                        .boxed(),
812                ))
813            })
814            .await
815            .unwrap()
816            .0
817        };
818
819        let mut client = build_client(vnode).await;
820
821        let aligned_epoch = client.align_initial_epoch(epoch1).await.unwrap();
822        assert_eq!(aligned_epoch, epoch1);
823
824        client
825            .commit(
826                epoch1,
827                SinkMetadata {
828                    metadata: Some(Metadata::Serialized(SerializedMetadata {
829                        metadata: metadata[0].clone(),
830                    })),
831                },
832                None,
833            )
834            .await
835            .unwrap();
836
837        client
838            .commit(
839                epoch2,
840                SinkMetadata {
841                    metadata: Some(Metadata::Serialized(SerializedMetadata {
842                        metadata: metadata[1].clone(),
843                    })),
844                },
845                None,
846            )
847            .await
848            .unwrap();
849    }
850
851    #[tokio::test]
852    async fn test_partial_commit() {
853        let db = prepare_db_backend().await;
854        let param = SinkParam {
855            sink_id: SinkId::from(1),
856            sink_name: "test".into(),
857            properties: Default::default(),
858            columns: vec![],
859            downstream_pk: None,
860            sink_type: SinkType::AppendOnly,
861            ignore_delete: false,
862            format_desc: None,
863            db_name: "test".into(),
864            sink_from_name: "test".into(),
865        };
866
867        let epoch = 233;
868
869        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
870        all_vnode.shuffle(&mut rand::rng());
871        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
872        let build_bitmap = |indexes: &[usize]| {
873            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
874            for i in indexes {
875                builder.set(*i, true);
876            }
877            builder.finish()
878        };
879        let vnode1 = build_bitmap(first);
880        let vnode2 = build_bitmap(second);
881
882        let sender = Arc::new(tokio::sync::Mutex::new(None));
883        let mock_subscriber: SinkCommittedEpochSubscriber = {
884            let captured_sender = sender.clone();
885            Arc::new(move |_sink_id: SinkId| {
886                let (sender, receiver) = unbounded_channel();
887                let captured_sender = captured_sender.clone();
888                async move {
889                    let mut guard = captured_sender.lock().await;
890                    *guard = Some(sender);
891                    Ok((1, receiver))
892                }
893                .boxed()
894            })
895        };
896        let (manager, (_join_handle, _stop_tx)) =
897            SinkCoordinatorManager::start_worker_with_spawn_worker({
898                let expected_param = param.clone();
899                let db = db.clone();
900                move |param, new_writer_rx| {
901                    let expected_param = expected_param.clone();
902                    let db = db.clone();
903                    tokio::spawn({
904                        let subscriber = mock_subscriber.clone();
905                        async move {
906                            // validate the start request
907                            assert_eq!(param, expected_param);
908                            CoordinatorWorker::execute_coordinator(
909                                db,
910                                param,
911                                new_writer_rx,
912                                MockSinglePhaseCoordinator::new_coordinator(
913                                    (),
914                                    |_, _, _| unreachable!(),
915                                ),
916                                subscriber.clone(),
917                            )
918                            .await;
919                        }
920                    })
921                }
922            });
923
924        let build_client = |vnode| async {
925            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
926                Ok(tonic::Response::new(
927                    manager
928                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
929                        .await
930                        .unwrap()
931                        .boxed(),
932                ))
933            })
934            .await
935            .unwrap()
936            .0
937        };
938
939        let (mut client1, client2) = join(build_client(vnode1), build_client(vnode2)).await;
940
941        // commit epoch
942        let mut commit_future = pin!(client1.commit(
943            epoch,
944            SinkMetadata {
945                metadata: Some(Metadata::Serialized(SerializedMetadata {
946                    metadata: vec![],
947                })),
948            },
949            None,
950        ));
951        assert!(
952            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
953                .await
954                .is_pending()
955        );
956        drop(client2);
957        assert!(commit_future.await.is_err());
958    }
959
960    #[tokio::test]
961    async fn test_fail_commit() {
962        let db = prepare_db_backend().await;
963        let param = SinkParam {
964            sink_id: SinkId::from(1),
965            sink_name: "test".into(),
966            properties: Default::default(),
967            columns: vec![],
968            downstream_pk: None,
969            sink_type: SinkType::AppendOnly,
970            ignore_delete: false,
971            format_desc: None,
972            db_name: "test".into(),
973            sink_from_name: "test".into(),
974        };
975
976        let epoch = 233;
977
978        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
979        all_vnode.shuffle(&mut rand::rng());
980        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
981        let build_bitmap = |indexes: &[usize]| {
982            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
983            for i in indexes {
984                builder.set(*i, true);
985            }
986            builder.finish()
987        };
988        let vnode1 = build_bitmap(first);
989        let vnode2 = build_bitmap(second);
990        let sender = Arc::new(tokio::sync::Mutex::new(None));
991        let mock_subscriber: SinkCommittedEpochSubscriber = {
992            let captured_sender = sender.clone();
993            Arc::new(move |_sink_id: SinkId| {
994                let (sender, receiver) = unbounded_channel();
995                let captured_sender = captured_sender.clone();
996                async move {
997                    let mut guard = captured_sender.lock().await;
998                    *guard = Some(sender);
999                    Ok((1, receiver))
1000                }
1001                .boxed()
1002            })
1003        };
1004        let (manager, (_join_handle, _stop_tx)) =
1005            SinkCoordinatorManager::start_worker_with_spawn_worker({
1006                let expected_param = param.clone();
1007                let db = db.clone();
1008                move |param, new_writer_rx| {
1009                    let expected_param = expected_param.clone();
1010                    let db = db.clone();
1011                    tokio::spawn({
1012                        let subscriber = mock_subscriber.clone();
1013                        {
1014                            async move {
1015                                // validate the start request
1016                                assert_eq!(param, expected_param);
1017                                CoordinatorWorker::execute_coordinator(
1018                                    db,
1019                                    param,
1020                                    new_writer_rx,
1021                                    MockSinglePhaseCoordinator::new_coordinator((), |_, _, _| {
1022                                        Err(SinkError::Coordinator(anyhow!("failed to commit")))
1023                                    }),
1024                                    subscriber.clone(),
1025                                )
1026                                .await;
1027                            }
1028                        }
1029                    })
1030                }
1031            });
1032
1033        let build_client = |vnode| async {
1034            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1035                Ok(tonic::Response::new(
1036                    manager
1037                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1038                        .await
1039                        .unwrap()
1040                        .boxed(),
1041                ))
1042            })
1043            .await
1044            .unwrap()
1045            .0
1046        };
1047
1048        let (mut client1, mut client2) = join(build_client(vnode1), build_client(vnode2)).await;
1049
1050        // commit epoch
1051        let mut commit_future = pin!(client1.commit(
1052            epoch,
1053            SinkMetadata {
1054                metadata: Some(Metadata::Serialized(SerializedMetadata {
1055                    metadata: vec![],
1056                })),
1057            },
1058            None,
1059        ));
1060        assert!(
1061            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
1062                .await
1063                .is_pending()
1064        );
1065        let (result1, result2) = join(
1066            commit_future,
1067            client2.commit(
1068                epoch,
1069                SinkMetadata {
1070                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1071                        metadata: vec![],
1072                    })),
1073                },
1074                None,
1075            ),
1076        )
1077        .await;
1078        assert!(result1.is_err());
1079        assert!(result2.is_err());
1080    }
1081
    /// Checks that writers can be rescaled without losing or mis-routing commits.
    ///
    /// Two writers first commit `epoch1` and `epoch2` on a half/half vnode split. A third writer
    /// then joins, the vnodes are re-split three ways, and all three writers commit `epoch3`.
    /// Finally `client1` stops, the remaining two writers take over the vnode space, and they
    /// commit `epoch4`. The mock coordinator checks every commit it receives against the expected
    /// epoch and sorted metadata, and panics if it sees more than four commits.
    #[tokio::test]
    async fn test_update_vnode_bitmap() {
        let db = prepare_db_backend().await;
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: None,
            sink_type: SinkType::AppendOnly,
            ignore_delete: false,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch1 = 233;
        let epoch2 = 234;
        let epoch3 = 235;
        let epoch4 = 236;

        // Initial layout: a random half of the vnodes per writer.
        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);

        // `metadata[e][i]` is what client `i + 1` commits for epoch `e + 1` (epoch1, epoch2).
        let metadata = [
            [vec![1u8, 2u8], vec![3u8, 4u8]],
            [vec![5u8, 6u8], vec![7u8, 8u8]],
        ];

        // `metadata_scale_out[i]` is committed by client `i + 1` at epoch3 (three writers);
        // `metadata_scale_in` is committed by client2 and client3 at epoch4.
        let metadata_scale_out = [vec![9u8, 10u8], vec![11u8, 12u8], vec![13u8, 14u8]];
        let metadata_scale_in = [vec![13u8, 14u8], vec![15u8, 16u8]];
        // Each subscription parks its sender here, presumably so the coordinator's receiver is
        // never closed; the test never sends on it.
        let sender = Arc::new(tokio::sync::Mutex::new(None));
        let mock_subscriber: SinkCommittedEpochSubscriber = {
            let captured_sender = sender.clone();
            Arc::new(move |_sink_id: SinkId| {
                let (sender, receiver) = unbounded_channel();
                let captured_sender = captured_sender.clone();
                async move {
                    let mut guard = captured_sender.lock().await;
                    *guard = Some(sender);
                    Ok((1, receiver))
                }
                .boxed()
            })
        };
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                let metadata = metadata.clone();
                let metadata_scale_out = metadata_scale_out.clone();
                let metadata_scale_in = metadata_scale_in.clone();
                let db = db.clone();
                move |param, new_writer_rx| {
                    let metadata = metadata.clone();
                    let metadata_scale_out = metadata_scale_out.clone();
                    let metadata_scale_in = metadata_scale_in.clone();
                    let expected_param = expected_param.clone();
                    let db = db.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                db,
                                param.clone(),
                                new_writer_rx,
                                // `count` is the coordinator's own state: the number of commits
                                // seen so far, used to pick the expected epoch/metadata.
                                MockSinglePhaseCoordinator::new_coordinator(
                                    0,
                                    move |epoch, metadata_list, count: &mut usize| {
                                        *count += 1;
                                        let mut metadata_list =
                                            metadata_list
                                                .into_iter()
                                                .map(|metadata| match metadata {
                                                    SinkMetadata {
                                                        metadata:
                                                            Some(Metadata::Serialized(
                                                                SerializedMetadata { metadata },
                                                            )),
                                                    } => metadata,
                                                    _ => unreachable!(),
                                                })
                                                .collect_vec();
                                        // Writers may be collected in any order; the expected
                                        // lists above are already sorted.
                                        metadata_list.sort();
                                        let (expected_epoch, expected_metadata_list) = match *count
                                        {
                                            1 => (epoch1, metadata[0].as_slice()),
                                            2 => (epoch2, metadata[1].as_slice()),
                                            3 => (epoch3, metadata_scale_out.as_slice()),
                                            4 => (epoch4, metadata_scale_in.as_slice()),
                                            _ => unreachable!(),
                                        };
                                        assert_eq!(expected_epoch, epoch);
                                        assert_eq!(expected_metadata_list, &metadata_list);
                                        Ok(())
                                    },
                                ),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        // Unlike the sibling tests, this returns the whole `Result`, so callers also see the
        // initial epoch (`.1`) reported to the new writer.
        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
        };

        // NOTE(review): `pin!` on only the second future looks unnecessary (`try_join` does not
        // appear to require `Unpin`); confirm before removing.
        let ((mut client1, _), (mut client2, _)) =
            try_join(build_client(vnode1), pin!(build_client(vnode2)))
                .await
                .unwrap();

        let (aligned_epoch1, aligned_epoch2) = try_join(
            client1.align_initial_epoch(epoch1),
            client2.align_initial_epoch(epoch1),
        )
        .await
        .unwrap();
        assert_eq!(aligned_epoch1, epoch1);
        assert_eq!(aligned_epoch2, epoch1);

        {
            // commit epoch1
            // client2's commit stays pending until client1 also commits.
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][1].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap())
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client1
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][0].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }

        // Scale-out layout: the same shuffled vnodes split three ways.
        let (vnode1, vnode2, vnode3) = {
            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            let (second, third) = second.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            (
                build_bitmap(first),
                build_bitmap(second),
                build_bitmap(third),
            )
        };

        // The third writer's registration is observed pending here; it only resolves below,
        // together with the existing writers' bitmap updates.
        let mut build_client3_future = pin!(build_client(vnode3));
        assert!(
            poll_fn(|cx| Poll::Ready(build_client3_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        let mut client3;
        {
            {
                // commit epoch2
                // Still on the old two-writer layout while client3 waits to join.
                let mut commit_future = pin!(
                    client1
                        .commit(
                            epoch2,
                            SinkMetadata {
                                metadata: Some(Metadata::Serialized(SerializedMetadata {
                                    metadata: metadata[1][0].clone(),
                                })),
                            },
                            None,
                        )
                        .map_err(Into::into)
                );
                assert!(
                    poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                        .await
                        .is_pending()
                );
                try_join(
                    commit_future,
                    client2.commit(
                        epoch2,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[1][1].clone(),
                            })),
                        },
                        None,
                    ),
                )
                .await
                .unwrap();
            }

            // client3 joins while client1/client2 switch to the three-way bitmaps; all three
            // report `epoch2`, the last epoch committed under the old layout.
            client3 = {
                let (
                    (client3, init_epoch),
                    (update_vnode_bitmap_epoch1, update_vnode_bitmap_epoch2),
                ) = try_join(
                    build_client3_future,
                    try_join(
                        client1.update_vnode_bitmap(&vnode1),
                        client2.update_vnode_bitmap(&vnode2),
                    )
                    .map_err(Into::into),
                )
                .await
                .unwrap();
                assert_eq!(init_epoch, Some(epoch2));
                assert_eq!(update_vnode_bitmap_epoch1, epoch2);
                assert_eq!(update_vnode_bitmap_epoch2, epoch2);
                client3
            };
            // commit epoch3 with three writers: neither client3's nor client1's commit may
            // complete until client2 has committed as well.
            let mut commit_future3 = pin!(client3.commit(
                epoch3,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata_scale_out[2].clone(),
                    })),
                },
                None,
            ));
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            let mut commit_future1 = pin!(client1.commit(
                epoch3,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata_scale_out[0].clone(),
                    })),
                },
                None,
            ));
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future1.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            try_join(
                client2.commit(
                    epoch3,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata_scale_out[1].clone(),
                        })),
                    },
                    None,
                ),
                try_join(commit_future1, commit_future3),
            )
            .await
            .unwrap();
        }

        // Scale-in layout: client2 takes the first third, client3 the remaining vnodes.
        let (vnode2, vnode3) = {
            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            (build_bitmap(first), build_bitmap(second))
        };

        {
            // client1 leaves while the others update their bitmaps; both updates report
            // `epoch3`, the last epoch committed with three writers.
            let (_, (update_vnode_bitmap_epoch2, update_vnode_bitmap_epoch3)) = try_join(
                client1.stop(),
                try_join(
                    client2.update_vnode_bitmap(&vnode2),
                    client3.update_vnode_bitmap(&vnode3),
                ),
            )
            .await
            .unwrap();
            assert_eq!(update_vnode_bitmap_epoch2, epoch3);
            assert_eq!(update_vnode_bitmap_epoch3, epoch3);
        }

        {
            // commit epoch4 with the two remaining writers.
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch4,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata_scale_in[0].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap())
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client3
                    .commit(
                        epoch4,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata_scale_in[1].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }
    }
1439
1440    struct MockTwoPhaseCoordinator<
1441        P: FnMut(
1442            u64,
1443            Vec<SinkMetadata>,
1444            Option<PbSinkSchemaChange>,
1445        ) -> Result<Option<Vec<u8>>, SinkError>,
1446        CD: FnMut(u64, Vec<u8>) -> Result<(), SinkError>,
1447        CS: FnMut(u64, PbSinkSchemaChange) -> Result<(), SinkError>,
1448    > {
1449        pre_commit: P,
1450        commit_data: CD,
1451        commit_schema_change: CS,
1452    }
1453
1454    impl<
1455        P: FnMut(
1456                u64,
1457                Vec<SinkMetadata>,
1458                Option<PbSinkSchemaChange>,
1459            ) -> Result<Option<Vec<u8>>, SinkError>
1460            + Send
1461            + 'static,
1462        CD: FnMut(u64, Vec<u8>) -> Result<(), SinkError> + Send + 'static,
1463        CS: FnMut(u64, PbSinkSchemaChange) -> Result<(), SinkError> + Send + 'static,
1464    > MockTwoPhaseCoordinator<P, CD, CS>
1465    {
1466        fn new_coordinator(
1467            pre_commit: P,
1468            commit_data: CD,
1469            commit_schema_change: CS,
1470        ) -> SinkCommitCoordinator {
1471            SinkCommitCoordinator::TwoPhase(Box::new(MockTwoPhaseCoordinator {
1472                pre_commit,
1473                commit_data,
1474                commit_schema_change,
1475            }))
1476        }
1477    }
1478
1479    #[async_trait]
1480    impl<
1481        P: FnMut(
1482                u64,
1483                Vec<SinkMetadata>,
1484                Option<PbSinkSchemaChange>,
1485            ) -> Result<Option<Vec<u8>>, SinkError>
1486            + Send
1487            + 'static,
1488        CD: FnMut(u64, Vec<u8>) -> Result<(), SinkError> + Send + 'static,
1489        CS: FnMut(u64, PbSinkSchemaChange) -> Result<(), SinkError> + Send + 'static,
1490    > TwoPhaseCommitCoordinator for MockTwoPhaseCoordinator<P, CD, CS>
1491    {
1492        async fn init(&mut self) -> risingwave_connector::sink::Result<()> {
1493            Ok(())
1494        }
1495
1496        async fn pre_commit(
1497            &mut self,
1498            epoch: u64,
1499            metadata: Vec<SinkMetadata>,
1500            schema_change: Option<PbSinkSchemaChange>,
1501        ) -> risingwave_connector::sink::Result<Option<Vec<u8>>> {
1502            (self.pre_commit)(epoch, metadata, schema_change)
1503        }
1504
1505        async fn commit_data(
1506            &mut self,
1507            epoch: u64,
1508            commit_metadata: Vec<u8>,
1509        ) -> risingwave_connector::sink::Result<()> {
1510            (self.commit_data)(epoch, commit_metadata)
1511        }
1512
1513        async fn commit_schema_change(
1514            &mut self,
1515            epoch: u64,
1516            schema_change: PbSinkSchemaChange,
1517        ) -> risingwave_connector::sink::Result<()> {
1518            (self.commit_schema_change)(epoch, schema_change)
1519        }
1520
1521        async fn abort(&mut self, _epoch: u64, _commit_metadata: Vec<u8>) {
1522            tracing::debug!("abort called");
1523        }
1524    }
1525
1526    async fn prepare_db_backend() -> DatabaseConnection {
1527        let db: DatabaseConnection = Database::connect("sqlite::memory:").await.unwrap();
1528        let ddl = "
1529            CREATE TABLE IF NOT EXISTS pending_sink_state (
1530                sink_id i32 NOT NULL,
1531                epoch i64 NOT NULL,
1532                sink_state STRING NOT NULL,
1533                metadata BLOB NOT NULL,
1534                schema_change BLOB,
1535                PRIMARY KEY (sink_id, epoch)
1536            )
1537        ";
1538        db.execute(sea_orm::Statement::from_string(
1539            db.get_database_backend(),
1540            ddl.to_owned(),
1541        ))
1542        .await
1543        .unwrap();
1544        db
1545    }
1546
1547    async fn list_rows(
1548        db: &DatabaseConnection,
1549    ) -> Vec<(i32, i64, String, Vec<u8>, Option<PbSinkSchemaChange>)> {
1550        let sql =
1551            "SELECT sink_id, epoch, sink_state, metadata, schema_change FROM pending_sink_state";
1552        let rows = db
1553            .query_all(sea_orm::Statement::from_string(
1554                db.get_database_backend(),
1555                sql.to_owned(),
1556            ))
1557            .await
1558            .unwrap();
1559        rows.into_iter()
1560            .map(|row| {
1561                (
1562                    row.try_get("", "sink_id").unwrap(),
1563                    row.try_get("", "epoch").unwrap(),
1564                    row.try_get("", "sink_state").unwrap(),
1565                    row.try_get("", "metadata").unwrap(),
1566                    row.try_get::<Option<SinkSchemachange>>("", "schema_change")
1567                        .unwrap()
1568                        .map(|v| v.to_protobuf()),
1569                )
1570            })
1571            .collect()
1572    }
1573
1574    async fn set_epoch_aborted(db: &DatabaseConnection, sink_id: SinkId, epoch: u64) {
1575        let sql = format!(
1576            "UPDATE pending_sink_state SET sink_state = 'ABORTED' WHERE sink_id = {} AND epoch = {}",
1577            sink_id, epoch as i64
1578        );
1579        db.execute(sea_orm::Statement::from_string(
1580            db.get_database_backend(),
1581            sql,
1582        ))
1583        .await
1584        .unwrap();
1585    }
1586
1587    #[tokio::test]
1588    async fn test_init_response_skips_recovered_pending_epoch() {
1589        let db = prepare_db_backend().await;
1590
1591        let param = SinkParam {
1592            sink_id: SinkId::from(1),
1593            sink_name: "test".into(),
1594            properties: Default::default(),
1595            columns: vec![],
1596            downstream_pk: None,
1597            sink_type: SinkType::AppendOnly,
1598            ignore_delete: false,
1599            format_desc: None,
1600            db_name: "test".into(),
1601            sink_from_name: "test".into(),
1602        };
1603
1604        let epoch1 = 233;
1605        let metadata = vec![1u8, 2u8];
1606
1607        let sql = format!(
1608            "INSERT INTO pending_sink_state (sink_id, epoch, sink_state, metadata) VALUES ({}, {}, 'PENDING', x'0102')",
1609            param.sink_id, epoch1
1610        );
1611        db.execute(sea_orm::Statement::from_string(
1612            db.get_database_backend(),
1613            sql,
1614        ))
1615        .await
1616        .unwrap();
1617
1618        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1619        let build_bitmap = |indexes: &[usize]| {
1620            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1621            for i in indexes {
1622                builder.set(*i, true);
1623            }
1624            builder.finish()
1625        };
1626        let vnode = build_bitmap(&all_vnode);
1627
1628        let sender = Arc::new(tokio::sync::Mutex::new(None));
1629        let mock_subscriber: SinkCommittedEpochSubscriber = {
1630            let captured_sender = sender.clone();
1631            Arc::new(move |_sink_id: SinkId| {
1632                let (epoch_sender, receiver) = unbounded_channel();
1633                let captured_sender = captured_sender.clone();
1634                async move {
1635                    let mut guard = captured_sender.lock().await;
1636                    *guard = Some(epoch_sender);
1637                    Ok((epoch1, receiver))
1638                }
1639                .boxed()
1640            })
1641        };
1642
1643        let pre_commit_attempt = Arc::new(AtomicI32::new(0));
1644        let commit_attempt = Arc::new(AtomicI32::new(0));
1645        let (manager, (_join_handle, _stop_tx)) =
1646            SinkCoordinatorManager::start_worker_with_spawn_worker({
1647                let expected_param = param.clone();
1648                let db = db.clone();
1649                let metadata = metadata.clone();
1650                let pre_commit_attempt = pre_commit_attempt.clone();
1651                let commit_attempt = commit_attempt.clone();
1652                move |param, new_writer_rx| {
1653                    let expected_param = expected_param.clone();
1654                    let db = db.clone();
1655                    let metadata = metadata.clone();
1656                    let pre_commit_attempt = pre_commit_attempt.clone();
1657                    let commit_attempt = commit_attempt.clone();
1658                    tokio::spawn({
1659                        let subscriber = mock_subscriber.clone();
1660                        async move {
1661                            assert_eq!(param, expected_param);
1662                            CoordinatorWorker::execute_coordinator(
1663                                db,
1664                                param.clone(),
1665                                new_writer_rx,
1666                                MockTwoPhaseCoordinator::new_coordinator(
1667                                    move |_epoch, _metadata_list, _schema_change| {
1668                                        pre_commit_attempt.fetch_add(
1669                                            1,
1670                                            std::sync::atomic::Ordering::SeqCst,
1671                                        );
1672                                        Err(SinkError::Coordinator(anyhow!(
1673                                            "pre_commit should not be called for a known pending epoch"
1674                                        )))
1675                                    },
1676                                    move |epoch, commit_metadata| {
1677                                        assert_eq!(epoch, epoch1);
1678                                        assert_eq!(commit_metadata, metadata);
1679                                        commit_attempt
1680                                            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1681                                        Ok(())
1682                                    },
1683                                    move |_epoch, _schema_change| unreachable!(),
1684                                ),
1685                                subscriber.clone(),
1686                            )
1687                            .await;
1688                        }
1689                    })
1690                }
1691            });
1692
1693        let (_client, log_store_rewind_start_epoch) =
1694            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1695                Ok(tonic::Response::new(
1696                    manager
1697                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1698                        .await
1699                        .unwrap()
1700                        .boxed(),
1701                ))
1702            })
1703            .await
1704            .unwrap();
1705        assert_eq!(log_store_rewind_start_epoch, Some(epoch1));
1706
1707        for _ in 0..50 {
1708            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1709            let rows = list_rows(&db).await;
1710            if rows[0].2 == "COMMITTED" {
1711                break;
1712            }
1713        }
1714
1715        assert_eq!(
1716            pre_commit_attempt.load(std::sync::atomic::Ordering::SeqCst),
1717            0
1718        );
1719        assert_eq!(commit_attempt.load(std::sync::atomic::Ordering::SeqCst), 1);
1720        let rows = list_rows(&db).await;
1721        assert_eq!(rows.len(), 1);
1722        assert_eq!(rows[0].1, epoch1 as i64);
1723        assert_eq!(rows[0].2, "COMMITTED");
1724    }
1725
1726    #[tokio::test]
1727    async fn test_pre_commit_failed() {
1728        let db = prepare_db_backend().await;
1729
1730        let param = SinkParam {
1731            sink_id: SinkId::from(1),
1732            sink_name: "test".into(),
1733            properties: Default::default(),
1734            columns: vec![],
1735            downstream_pk: None,
1736            sink_type: SinkType::AppendOnly,
1737            ignore_delete: false,
1738            format_desc: None,
1739            db_name: "test".into(),
1740            sink_from_name: "test".into(),
1741        };
1742
1743        let epoch1 = 233;
1744
1745        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1746        let build_bitmap = |indexes: &[usize]| {
1747            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1748            for i in indexes {
1749                builder.set(*i, true);
1750            }
1751            builder.finish()
1752        };
1753        let vnode = build_bitmap(&all_vnode);
1754
1755        let metadata = vec![1u8, 2u8];
1756        let sender = Arc::new(tokio::sync::Mutex::new(None));
1757        let mock_subscriber: SinkCommittedEpochSubscriber = {
1758            let captured_sender = sender.clone();
1759            Arc::new(move |_sink_id: SinkId| {
1760                let (sender, receiver) = unbounded_channel();
1761                let captured_sender = captured_sender.clone();
1762                async move {
1763                    let mut guard = captured_sender.lock().await;
1764                    *guard = Some(sender);
1765                    Ok((epoch1, receiver))
1766                }
1767                .boxed()
1768            })
1769        };
1770
1771        let (manager, (_join_handle, _stop_tx)) =
1772            SinkCoordinatorManager::start_worker_with_spawn_worker({
1773                let expected_param = param.clone();
1774                let db = db.clone();
1775                move |param, new_writer_rx| {
1776                    let expected_param = expected_param.clone();
1777                    let db = db.clone();
1778                    tokio::spawn({
1779                        let subscriber = mock_subscriber.clone();
1780                        async move {
1781                            // validate the start request
1782                            assert_eq!(param, expected_param);
1783                            CoordinatorWorker::execute_coordinator(
1784                                db,
1785                                param.clone(),
1786                                new_writer_rx,
1787                                MockTwoPhaseCoordinator::new_coordinator(
1788                                    move |_epoch, _metadata_list, _schema_change| {
1789                                        Err(SinkError::Coordinator(anyhow!("failed to pre commit")))
1790                                    },
1791                                    move |_epoch, _commit_metadata| unreachable!(),
1792                                    move |_epoch, _schema_change| unreachable!(),
1793                                ),
1794                                subscriber.clone(),
1795                            )
1796                            .await;
1797                        }
1798                    })
1799                }
1800            });
1801
1802        let build_client = |vnode| async {
1803            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1804                Ok(tonic::Response::new(
1805                    manager
1806                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1807                        .await
1808                        .unwrap()
1809                        .boxed(),
1810                ))
1811            })
1812            .await
1813            .unwrap()
1814            .0
1815        };
1816
1817        let mut client = build_client(vnode).await;
1818
1819        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
1820        assert_eq!(aligned_epoch, 1);
1821
1822        let commit_result = client
1823            .commit(
1824                epoch1,
1825                SinkMetadata {
1826                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1827                        metadata: metadata.clone(),
1828                    })),
1829                },
1830                None,
1831            )
1832            .await;
1833        assert!(commit_result.is_err());
1834
1835        let rows = list_rows(&db).await;
1836        assert!(rows.is_empty());
1837    }
1838
1839    #[tokio::test]
1840    async fn test_waiting_on_checkpoint() {
1841        let db = prepare_db_backend().await;
1842
1843        let param = SinkParam {
1844            sink_id: SinkId::from(1),
1845            sink_name: "test".into(),
1846            properties: Default::default(),
1847            columns: vec![],
1848            downstream_pk: None,
1849            sink_type: SinkType::AppendOnly,
1850            ignore_delete: false,
1851            format_desc: None,
1852            db_name: "test".into(),
1853            sink_from_name: "test".into(),
1854        };
1855
1856        let epoch0 = 232;
1857        let epoch1 = 233;
1858
1859        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1860        let build_bitmap = |indexes: &[usize]| {
1861            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1862            for i in indexes {
1863                builder.set(*i, true);
1864            }
1865            builder.finish()
1866        };
1867        let vnode = build_bitmap(&all_vnode);
1868
1869        let metadata = vec![1u8, 2u8];
1870
1871        let sender = Arc::new(tokio::sync::Mutex::new(None));
1872        let mock_subscriber: SinkCommittedEpochSubscriber = {
1873            let captured_sender = sender.clone();
1874            Arc::new(move |_sink_id: SinkId| {
1875                let (sender, receiver) = unbounded_channel();
1876                let captured_sender = captured_sender.clone();
1877                async move {
1878                    let mut guard = captured_sender.lock().await;
1879                    *guard = Some(sender);
1880                    Ok((epoch0, receiver))
1881                }
1882                .boxed()
1883            })
1884        };
1885
1886        let (manager, (_join_handle, _stop_tx)) =
1887            SinkCoordinatorManager::start_worker_with_spawn_worker({
1888                let expected_param = param.clone();
1889                let metadata = metadata.clone();
1890                let db = db.clone();
1891                move |param, new_writer_rx| {
1892                    let metadata = metadata.clone();
1893                    let expected_param = expected_param.clone();
1894                    let db = db.clone();
1895                    tokio::spawn({
1896                        let subscriber = mock_subscriber.clone();
1897                        async move {
1898                            // validate the start request
1899                            assert_eq!(param, expected_param);
1900                            CoordinatorWorker::execute_coordinator(
1901                                db,
1902                                param.clone(),
1903                                new_writer_rx,
1904                                MockTwoPhaseCoordinator::new_coordinator(
1905                                    move |_epoch, metadata_list, _schema_change| {
1906                                        let metadata =
1907                                            Itertools::exactly_one(metadata_list.into_iter())
1908                                                .unwrap();
1909                                        Ok(match metadata.metadata {
1910                                            Some(Metadata::Serialized(SerializedMetadata {
1911                                                metadata,
1912                                            })) => Some(metadata),
1913                                            _ => unreachable!(),
1914                                        })
1915                                    },
1916                                    move |_epoch, commit_metadata| {
1917                                        assert_eq!(commit_metadata, metadata);
1918                                        Ok(())
1919                                    },
1920                                    move |_epoch, _schema_change| unreachable!(),
1921                                ),
1922                                subscriber.clone(),
1923                            )
1924                            .await;
1925                        }
1926                    })
1927                }
1928            });
1929
1930        let build_client = |vnode| async {
1931            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1932                Ok(tonic::Response::new(
1933                    manager
1934                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1935                        .await
1936                        .unwrap()
1937                        .boxed(),
1938                ))
1939            })
1940            .await
1941            .unwrap()
1942            .0
1943        };
1944
1945        let mut client = build_client(vnode).await;
1946
1947        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
1948        assert_eq!(aligned_epoch, 1);
1949
1950        client
1951            .commit(
1952                epoch1,
1953                SinkMetadata {
1954                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1955                        metadata: metadata.clone(),
1956                    })),
1957                },
1958                None,
1959            )
1960            .await
1961            .unwrap();
1962
1963        {
1964            let rows = list_rows(&db).await;
1965            assert_eq!(rows.len(), 1);
1966            assert_eq!(rows[0].1, epoch1 as i64);
1967            assert_eq!(rows[0].2, "PENDING");
1968
1969            let guard = sender.lock().await;
1970            let sender = guard.as_ref().unwrap().clone();
1971            sender.send(233).unwrap();
1972        }
1973
1974        // wait max 5 seconds for the commit to be processed
1975        for _ in 0..50 {
1976            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1977            let rows = list_rows(&db).await;
1978            if rows[0].2 == "COMMITTED" {
1979                break;
1980            }
1981        }
1982
1983        {
1984            let rows = list_rows(&db).await;
1985            assert_eq!(rows.len(), 1);
1986            assert_eq!(rows[0].1, epoch1 as i64);
1987            assert_eq!(rows[0].2, "COMMITTED");
1988        }
1989    }
1990
1991    #[tokio::test]
1992    async fn test_commit_retry_loop() {
1993        let db = prepare_db_backend().await;
1994
1995        let param = SinkParam {
1996            sink_id: SinkId::from(1),
1997            sink_name: "test".into(),
1998            properties: Default::default(),
1999            columns: vec![],
2000            downstream_pk: None,
2001            sink_type: SinkType::AppendOnly,
2002            ignore_delete: false,
2003            format_desc: None,
2004            db_name: "test".into(),
2005            sink_from_name: "test".into(),
2006        };
2007
2008        let epoch1 = 233;
2009
2010        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
2011        let build_bitmap = |indexes: &[usize]| {
2012            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
2013            for i in indexes {
2014                builder.set(*i, true);
2015            }
2016            builder.finish()
2017        };
2018        let vnode = build_bitmap(&all_vnode);
2019
2020        let metadata = vec![1u8, 2u8];
2021        let sender = Arc::new(tokio::sync::Mutex::new(None));
2022        let mock_subscriber: SinkCommittedEpochSubscriber = {
2023            let captured_sender = sender.clone();
2024            Arc::new(move |_sink_id: SinkId| {
2025                let (sender, receiver) = unbounded_channel();
2026                let captured_sender = captured_sender.clone();
2027                async move {
2028                    let mut guard = captured_sender.lock().await;
2029                    *guard = Some(sender);
2030                    Ok((epoch1, receiver))
2031                }
2032                .boxed()
2033            })
2034        };
2035
2036        let commit_attempt = Arc::new(AtomicI32::new(0));
2037
2038        let (manager, (_join_handle, _stop_tx)) =
2039            SinkCoordinatorManager::start_worker_with_spawn_worker({
2040                let expected_param = param.clone();
2041                let metadata = metadata.clone();
2042                let db = db.clone();
2043                let commit_attempt = commit_attempt.clone();
2044                move |param, new_writer_rx| {
2045                    let metadata = metadata.clone();
2046                    let expected_param = expected_param.clone();
2047                    let db = db.clone();
2048                    let commit_attempt = commit_attempt.clone();
2049                    tokio::spawn({
2050                        let subscriber = mock_subscriber.clone();
2051                        async move {
2052                            // validate the start request
2053                            assert_eq!(param, expected_param);
2054                            CoordinatorWorker::execute_coordinator(
2055                                db,
2056                                param.clone(),
2057                                new_writer_rx,
2058                                MockTwoPhaseCoordinator::new_coordinator(
2059                                    move |_epoch, metadata_list, _schema_change| {
2060                                        let metadata =
2061                                            Itertools::exactly_one(metadata_list.into_iter())
2062                                                .unwrap();
2063                                        Ok(match metadata.metadata {
2064                                            Some(Metadata::Serialized(SerializedMetadata {
2065                                                metadata,
2066                                            })) => Some(metadata),
2067                                            _ => unreachable!(),
2068                                        })
2069                                    },
2070                                    move |_epoch, commit_metadata| {
2071                                        assert_eq!(commit_metadata, metadata);
2072                                        if commit_attempt
2073                                            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
2074                                            < 2
2075                                        {
2076                                            Err(SinkError::Coordinator(anyhow!("failed to commit")))
2077                                        } else {
2078                                            Ok(())
2079                                        }
2080                                    },
2081                                    move |_epoch, _schema_change| unreachable!(),
2082                                ),
2083                                subscriber.clone(),
2084                            )
2085                            .await;
2086                        }
2087                    })
2088                }
2089            });
2090
2091        let build_client = |vnode| async {
2092            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
2093                Ok(tonic::Response::new(
2094                    manager
2095                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
2096                        .await
2097                        .unwrap()
2098                        .boxed(),
2099                ))
2100            })
2101            .await
2102            .unwrap()
2103            .0
2104        };
2105
2106        let mut client = build_client(vnode).await;
2107
2108        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
2109        assert_eq!(aligned_epoch, 1);
2110
2111        client
2112            .commit(
2113                epoch1,
2114                SinkMetadata {
2115                    metadata: Some(Metadata::Serialized(SerializedMetadata {
2116                        metadata: metadata.clone(),
2117                    })),
2118                },
2119                None,
2120            )
2121            .await
2122            .unwrap();
2123
2124        // wait max 10 seconds for the commit to be processed
2125        for _ in 0..100 {
2126            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2127            let rows = list_rows(&db).await;
2128            if rows[0].2 == "COMMITTED" {
2129                break;
2130            }
2131        }
2132
2133        assert_eq!(commit_attempt.load(std::sync::atomic::Ordering::SeqCst), 3);
2134
2135        {
2136            let rows = list_rows(&db).await;
2137            assert_eq!(rows.len(), 1);
2138            assert_eq!(rows[0].1, epoch1 as i64);
2139            assert_eq!(rows[0].2, "COMMITTED");
2140        }
2141    }
2142
2143    #[tokio::test]
2144    async fn test_aborted() {
2145        let db = prepare_db_backend().await;
2146
2147        let param = SinkParam {
2148            sink_id: SinkId::from(1),
2149            sink_name: "test".into(),
2150            properties: Default::default(),
2151            columns: vec![],
2152            downstream_pk: None,
2153            sink_type: SinkType::AppendOnly,
2154            ignore_delete: false,
2155            format_desc: None,
2156            db_name: "test".into(),
2157            sink_from_name: "test".into(),
2158        };
2159
2160        let epoch0 = 232;
2161        let epoch1 = 233;
2162
2163        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
2164        let build_bitmap = |indexes: &[usize]| {
2165            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
2166            for i in indexes {
2167                builder.set(*i, true);
2168            }
2169            builder.finish()
2170        };
2171        let vnode = build_bitmap(&all_vnode);
2172
2173        let metadata = vec![1u8, 2u8];
2174
2175        let sender = Arc::new(tokio::sync::Mutex::new(None));
2176        let mock_subscriber: SinkCommittedEpochSubscriber = {
2177            let captured_sender = sender.clone();
2178            Arc::new(move |_sink_id: SinkId| {
2179                let (sender, receiver) = unbounded_channel();
2180                let captured_sender = captured_sender.clone();
2181                async move {
2182                    let mut guard = captured_sender.lock().await;
2183                    *guard = Some(sender);
2184                    Ok((epoch0, receiver))
2185                }
2186                .boxed()
2187            })
2188        };
2189
2190        let (manager, (_join_handle, _stop_tx)) =
2191            SinkCoordinatorManager::start_worker_with_spawn_worker({
2192                let expected_param = param.clone();
2193                let metadata = metadata.clone();
2194                let db = db.clone();
2195                move |param, new_writer_rx| {
2196                    let metadata = metadata.clone();
2197                    let expected_param = expected_param.clone();
2198                    let db = db.clone();
2199                    tokio::spawn({
2200                        let subscriber = mock_subscriber.clone();
2201                        async move {
2202                            // validate the start request
2203                            assert_eq!(param, expected_param);
2204                            CoordinatorWorker::execute_coordinator(
2205                                db,
2206                                param.clone(),
2207                                new_writer_rx,
2208                                MockTwoPhaseCoordinator::new_coordinator(
2209                                    move |_epoch, metadata_list, _schema_change| {
2210                                        let metadata =
2211                                            Itertools::exactly_one(metadata_list.into_iter())
2212                                                .unwrap();
2213                                        Ok(match metadata.metadata {
2214                                            Some(Metadata::Serialized(SerializedMetadata {
2215                                                metadata,
2216                                            })) => Some(metadata),
2217                                            _ => unreachable!(),
2218                                        })
2219                                    },
2220                                    move |_epoch, commit_metadata| {
2221                                        assert_eq!(commit_metadata, metadata);
2222                                        Ok(())
2223                                    },
2224                                    move |_epoch, _schema_change| unreachable!(),
2225                                ),
2226                                subscriber.clone(),
2227                            )
2228                            .await;
2229                        }
2230                    })
2231                }
2232            });
2233
2234        let build_client = |vnode| async {
2235            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
2236                Ok(tonic::Response::new(
2237                    manager
2238                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
2239                        .await
2240                        .unwrap()
2241                        .boxed(),
2242                ))
2243            })
2244            .await
2245            .unwrap()
2246            .0
2247        };
2248
2249        let mut client = build_client(vnode.clone()).await;
2250
2251        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
2252        assert_eq!(aligned_epoch, 1);
2253
2254        client
2255            .commit(
2256                epoch1,
2257                SinkMetadata {
2258                    metadata: Some(Metadata::Serialized(SerializedMetadata {
2259                        metadata: metadata.clone(),
2260                    })),
2261                },
2262                None,
2263            )
2264            .await
2265            .unwrap();
2266
2267        manager.stop_sink_coordinator(vec![SinkId::from(1)]).await;
2268
2269        {
2270            let rows = list_rows(&db).await;
2271            assert_eq!(rows.len(), 1);
2272            assert_eq!(rows[0].1, epoch1 as i64);
2273            assert_eq!(rows[0].2, "PENDING");
2274
2275            set_epoch_aborted(&db, SinkId::from(1), epoch1).await;
2276            let rows = list_rows(&db).await;
2277            assert_eq!(rows.len(), 1);
2278            assert_eq!(rows[0].1, epoch1 as i64);
2279            assert_eq!(rows[0].2, "ABORTED");
2280        }
2281
2282        let mut client = build_client(vnode).await;
2283
2284        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
2285        assert_eq!(aligned_epoch, 1);
2286
2287        {
2288            let rows = list_rows(&db).await;
2289            assert!(rows.is_empty());
2290        }
2291    }
2292
2293    #[tokio::test]
        /// Checks that a writer reconnecting for the same vnodes (a reschedule) is held back
        /// until an earlier, only pre-committed epoch has been reported as committed.
        ///
        /// `client1` commits `epoch1` together with a schema change, which leaves a single
        /// `PENDING` row in `db`. A second client built for the same vnode bitmap must still be
        /// pending both while `client1` is connected and after `client1.stop()`. Only once the
        /// mock committed-epoch subscriber reports `epoch1` does the second client finish, with
        /// `Some(epoch1)` as its init epoch, and the row for `epoch1` becomes `COMMITTED`.
2294    async fn test_flush_when_reschedule() {
2295        let db = prepare_db_backend().await;
2296
2297        let param = SinkParam {
2298            sink_id: SinkId::from(1),
2299            sink_name: "test".into(),
2300            properties: Default::default(),
2301            columns: vec![],
2302            downstream_pk: None,
2303            sink_type: SinkType::AppendOnly,
2304            ignore_delete: false,
2305            format_desc: None,
2306            db_name: "test".into(),
2307            sink_from_name: "test".into(),
2308        };
2309
            // `epoch0` is what the mock subscriber returns when a coordinator subscribes;
            // `epoch1` is the epoch that `client1` commits.
2310        let epoch0 = 232;
2311        let epoch1 = 233;
2312
            // A bitmap with every vnode set; both clients below are built with (clones of) it.
2313        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
2314        let build_bitmap = |indexes: &[usize]| {
2315            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
2316            for i in indexes {
2317                builder.set(*i, true);
2318            }
2319            builder.finish()
2320        };
2321        let vnode = build_bitmap(&all_vnode);
2322
            // Payload and schema change sent with client1's commit. The coordinator hooks below
            // assert that they receive exactly these values.
2323        let metadata = vec![1u8, 2u8];
2324        let schema_change = PbSinkSchemaChange {
2325            original_schema: vec![PbField {
2326                data_type: Some(PbDataType {
2327                    type_name: PbTypeName::Int32 as i32,
2328                    ..Default::default()
2329                }),
2330                name: "col_v1".into(),
2331            }],
2332            op: Some(SinkSchemachangeOp::AddColumns(PbSinkAddColumnsOp {
2333                fields: vec![PbField {
2334                    data_type: Some(PbDataType {
2335                        type_name: PbTypeName::Varchar as i32,
2336                        ..Default::default()
2337                    }),
2338                    name: "new_col".into(),
2339                }],
2340            })),
2341        };
2342
            // Slot for the committed-epoch sender. Each subscription made through
            // `mock_subscriber` stores the sending half of a fresh channel here and replies with
            // `epoch0` plus the receiving half, so the test can later report `epoch1` by hand.
2343        let sender = Arc::new(tokio::sync::Mutex::new(None));
2344        let mock_subscriber: SinkCommittedEpochSubscriber = {
2345            let captured_sender = sender.clone();
2346            Arc::new(move |_sink_id: SinkId| {
2347                let (sender, receiver) = unbounded_channel();
2348                let captured_sender = captured_sender.clone();
2349                async move {
2350                    let mut guard = captured_sender.lock().await;
2351                    *guard = Some(sender);
2352                    Ok((epoch0, receiver))
2353                }
2354                .boxed()
2355            })
2356        };
2357
            // Start the manager with a spawn hook that runs the real
            // `CoordinatorWorker::execute_coordinator` on `db`, backed by a
            // `MockTwoPhaseCoordinator` whose hooks check the values they are given.
2358        let (manager, (_join_handle, _stop_tx)) =
2359            SinkCoordinatorManager::start_worker_with_spawn_worker({
2360                let expected_param = param.clone();
2361                let metadata = metadata.clone();
2362                let schema_change = schema_change.clone();
2363                let db = db.clone();
2364                move |param, new_writer_rx| {
2365                    let metadata = metadata.clone();
2366                    let schema_change_for_pre_commit = schema_change.clone();
2367                    let schema_change_for_commit = schema_change.clone();
2368                    let expected_param = expected_param.clone();
2369                    let db = db.clone();
2370                    tokio::spawn({
2371                        let subscriber = mock_subscriber.clone();
2372                        async move {
2373                            assert_eq!(param, expected_param);
2374                            CoordinatorWorker::execute_coordinator(
2375                                db,
2376                                param.clone(),
2377                                new_writer_rx,
2378                                MockTwoPhaseCoordinator::new_coordinator(
                                        // Pre-commit hook: expects the schema change and
                                        // metadata from exactly one writer, and returns that
                                        // writer's serialized bytes.
2379                                    move |_epoch, metadata_list, schema_change| {
2380                                        assert_eq!(
2381                                            schema_change,
2382                                            Some(schema_change_for_pre_commit.clone())
2383                                        );
2384                                        let metadata =
2385                                            Itertools::exactly_one(metadata_list.into_iter())
2386                                                .unwrap();
2387                                        Ok(match metadata.metadata {
2388                                            Some(Metadata::Serialized(SerializedMetadata {
2389                                                metadata,
2390                                            })) => Some(metadata),
2391                                            _ => unreachable!(),
2392                                        })
2393                                    },
                                        // Commit hook: expects the original metadata bytes.
2394                                    move |_epoch, commit_metadata| {
2395                                        assert_eq!(commit_metadata, metadata);
2396                                        Ok(())
2397                                    },
                                        // NOTE(review): presumably the schema-change commit
                                        // hook; it expects the same schema change.
2398                                    move |_epoch, schema_change| {
2399                                        assert_eq!(schema_change, schema_change_for_commit.clone());
2400                                        Ok(())
2401                                    },
2402                                ),
2403                                subscriber.clone(),
2404                            )
2405                            .await;
2406                        }
2407                    })
2408                }
2409            });
2410
            // Returns a future that builds a writer-side `CoordinatorStreamHandle` for `vnode`,
            // with its request stream served by `manager.handle_new_request`.
2411        let build_client = |vnode| async {
2412            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
2413                Ok(tonic::Response::new(
2414                    manager
2415                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
2416                        .await
2417                        .unwrap()
2418                        .boxed(),
2419                ))
2420            })
2421            .await
2422        };
2423
            // First writer: align the initial epoch, then commit `epoch1` with the schema change.
2424        let (mut client1, _) = build_client(vnode.clone()).await.unwrap();
2425
2426        let aligned_epoch = client1.align_initial_epoch(1).await.unwrap();
2427        assert_eq!(aligned_epoch, 1);
2428
2429        client1
2430            .commit(
2431                epoch1,
2432                SinkMetadata {
2433                    metadata: Some(Metadata::Serialized(SerializedMetadata {
2434                        metadata: metadata.clone(),
2435                    })),
2436                },
2437                Some(schema_change.clone()),
2438            )
2439            .await
2440            .unwrap();
2441
            // After the commit, the row for `epoch1` is only `PENDING` and stores the schema
            // change.
2442        {
2443            let rows = list_rows(&db).await;
2444            assert_eq!(rows.len(), 1);
2445            assert_eq!(rows[0].1, epoch1 as i64);
2446            assert_eq!(rows[0].2, "PENDING");
2447            assert_eq!(rows[0].4, Some(schema_change.clone()));
2448        }
2449
            // Start building a second writer for the same vnodes. `poll_fn` wraps a single
            // non-blocking poll of that future, so `is_pending()` asserts it has not completed.
2450        let mut build_client2_future = pin!(build_client(vnode.clone()));
2451        assert!(
2452            poll_fn(|cx| Poll::Ready(build_client2_future.as_mut().poll(cx)))
2453                .await
2454                .is_pending()
2455        );
2456
            // The second writer must stay pending even after `client1` stops, presumably
            // because `epoch1` has not yet been reported as committed.
2457        client1.stop().await.unwrap();
2458
2459        assert!(
2460            poll_fn(|cx| Poll::Ready(build_client2_future.as_mut().poll(cx)))
2461                .await
2462                .is_pending()
2463        );
2464
            // Report `epoch1` as committed through the sender captured by `mock_subscriber`.
2465        {
2466            let guard = sender.lock().await;
2467            let sender = guard.as_ref().unwrap().clone();
2468            sender.send(epoch1).unwrap();
2469        }
2470
            // That unblocks the second writer; its init epoch is `epoch1`.
2471        let (_, init_epoch) = build_client2_future.await.unwrap();
2472        assert_eq!(init_epoch, Some(epoch1));
2473
            // The row for `epoch1` is now `COMMITTED` and still carries the schema change.
2474        {
2475            let rows = list_rows(&db).await;
2476            assert_eq!(rows.len(), 1);
2477            assert_eq!(rows[0].1, epoch1 as i64);
2478            assert_eq!(rows[0].2, "COMMITTED");
2479            assert_eq!(rows[0].4, Some(schema_change.clone()));
2480        }
2481    }
2482}