risingwave_meta/manager/sink_coordination/
manager.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::pin::pin;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use futures::future::{BoxFuture, Either, join_all, select};
21use futures::stream::FuturesUnordered;
22use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
23use risingwave_common::bitmap::Bitmap;
24use risingwave_connector::connector_common::IcebergSinkCompactionUpdate;
25use risingwave_connector::sink::catalog::SinkId;
26use risingwave_connector::sink::{SinkCommittedEpochSubscriber, SinkError, SinkParam};
27use risingwave_pb::connector_service::coordinate_request::Msg;
28use risingwave_pb::connector_service::{CoordinateRequest, CoordinateResponse, coordinate_request};
29use rw_futures_util::pending_on_none;
30use sea_orm::DatabaseConnection;
31use thiserror_ext::AsReport;
32use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
33use tokio::sync::oneshot::{Receiver, Sender, channel};
34use tokio::sync::{mpsc, oneshot};
35use tokio::task::{JoinError, JoinHandle};
36use tokio_stream::wrappers::UnboundedReceiverStream;
37use tonic::Status;
38use tracing::{debug, error, info, warn};
39
40use crate::hummock::HummockManagerRef;
41use crate::manager::MetadataManager;
42use crate::manager::sink_coordination::SinkWriterRequestStream;
43use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
44use crate::manager::sink_coordination::handle::SinkWriterCoordinationHandle;
45
/// Calls the synchronous `send` method of `$sender` with `$payload` and emits an `error!` log
/// when the call returns an error. The failure is only logged, never propagated.
macro_rules! send_with_err_check {
    ($sender:expr, $payload:expr) => {
        if $sender.send($payload).is_err() {
            error!("unable to send msg");
        }
    };
}
53
/// Async counterpart of the synchronous send helper: awaits `$sender.send($payload)` and emits
/// an `error!` log when the awaited result is an error. Must be expanded inside an async
/// context; the failure is only logged, never propagated.
macro_rules! send_await_with_err_check {
    ($sender:expr, $payload:expr) => {
        if $sender.send($payload).await.is_err() {
            error!("unable to send msg");
        }
    };
}
61
/// Capacity of the bounded channel that carries `ManagerRequest`s from
/// `SinkCoordinatorManager` handles to the manager worker.
const BOUNDED_CHANNEL_SIZE: usize = 16;
63
64enum ManagerRequest {
65    NewSinkWriter(SinkWriterCoordinationHandle),
66    StopCoordinator {
67        finish_notifier: Sender<()>,
68        /// sink id to stop. When `None`, stop all sink coordinator
69        sink_ids: Option<Vec<SinkId>>,
70    },
71}
72
73#[derive(Clone)]
74pub struct SinkCoordinatorManager {
75    request_tx: mpsc::Sender<ManagerRequest>,
76}
77
78fn new_committed_epoch_subscriber(
79    hummock_manager: HummockManagerRef,
80    metadata_manager: MetadataManager,
81) -> SinkCommittedEpochSubscriber {
82    Arc::new(move |sink_id| {
83        let hummock_manager = hummock_manager.clone();
84        let metadata_manager = metadata_manager.clone();
85        async move {
86            let state_table_ids = metadata_manager
87                .get_sink_state_table_ids(sink_id)
88                .await
89                .map_err(SinkError::from)?;
90            let Some(table_id) = state_table_ids.first() else {
91                return Err(anyhow!("no state table id in sink: {}", sink_id).into());
92            };
93            hummock_manager
94                .subscribe_table_committed_epoch(*table_id)
95                .await
96                .map_err(SinkError::from)
97        }
98        .boxed()
99    })
100}
101
102impl SinkCoordinatorManager {
103    pub fn start_worker(
104        db: DatabaseConnection,
105        hummock_manager: HummockManagerRef,
106        metadata_manager: MetadataManager,
107        iceberg_compact_stat_sender: UnboundedSender<IcebergSinkCompactionUpdate>,
108    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
109        let subscriber = new_committed_epoch_subscriber(hummock_manager, metadata_manager);
110        Self::start_worker_with_spawn_worker(move |param, manager_request_stream| {
111            tokio::spawn(CoordinatorWorker::run(
112                param,
113                manager_request_stream,
114                db.clone(),
115                subscriber.clone(),
116                iceberg_compact_stat_sender.clone(),
117            ))
118        })
119    }
120
121    fn start_worker_with_spawn_worker(
122        spawn_coordinator_worker: impl SpawnCoordinatorFn,
123    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
124        let (request_tx, request_rx) = mpsc::channel(BOUNDED_CHANNEL_SIZE);
125        let (shutdown_tx, shutdown_rx) = channel();
126        let worker = ManagerWorker::new(request_rx, shutdown_rx);
127        let join_handle = tokio::spawn(worker.execute(spawn_coordinator_worker));
128        (
129            SinkCoordinatorManager { request_tx },
130            (join_handle, shutdown_tx),
131        )
132    }
133
134    pub async fn handle_new_request(
135        &self,
136        mut request_stream: SinkWriterRequestStream,
137    ) -> Result<impl Stream<Item = Result<CoordinateResponse, Status>> + use<>, Status> {
138        let (param, vnode_bitmap) = match request_stream.try_next().await? {
139            Some(CoordinateRequest {
140                msg:
141                    Some(Msg::StartRequest(coordinate_request::StartCoordinationRequest {
142                        param: Some(param),
143                        vnode_bitmap: Some(vnode_bitmap),
144                    })),
145            }) => (SinkParam::from_proto(param), Bitmap::from(&vnode_bitmap)),
146            msg => {
147                return Err(Status::invalid_argument(format!(
148                    "expected CoordinateRequest::StartRequest in the first request, get {:?}",
149                    msg
150                )));
151            }
152        };
153        let (response_tx, response_rx) = mpsc::unbounded_channel();
154        self.request_tx
155            .send(ManagerRequest::NewSinkWriter(
156                SinkWriterCoordinationHandle::new(request_stream, response_tx, param, vnode_bitmap),
157            ))
158            .await
159            .map_err(|_| {
160                Status::unavailable(
161                    "unable to send to sink manager worker. The worker may have stopped",
162                )
163            })?;
164
165        Ok(UnboundedReceiverStream::new(response_rx))
166    }
167
168    async fn stop_coordinator(&self, sink_ids: Option<Vec<SinkId>>) {
169        let (tx, rx) = channel();
170        send_await_with_err_check!(
171            self.request_tx,
172            ManagerRequest::StopCoordinator {
173                finish_notifier: tx,
174                sink_ids: sink_ids.clone(),
175            }
176        );
177        if rx.await.is_err() {
178            error!("fail to wait for resetting sink manager worker");
179        }
180        info!("successfully stop coordinator: {:?}", sink_ids);
181    }
182
183    pub async fn reset(&self) {
184        self.stop_coordinator(None).await;
185    }
186
187    pub async fn stop_sink_coordinator(&self, sink_ids: Vec<SinkId>) {
188        self.stop_coordinator(Some(sink_ids)).await;
189    }
190}
191
192struct CoordinatorWorkerHandle {
193    /// Sender to coordinator worker. Drop the sender as a stop signal
194    request_sender: Option<UnboundedSender<SinkWriterCoordinationHandle>>,
195    /// Notify when the coordinator worker stops
196    finish_notifiers: Vec<Sender<()>>,
197}
198
199struct ManagerWorker {
200    request_rx: mpsc::Receiver<ManagerRequest>,
201    // Make it option so that it can be polled with &mut SinkManagerWorker
202    shutdown_rx: Receiver<()>,
203
204    running_coordinator_worker_join_handles:
205        FuturesUnordered<BoxFuture<'static, (SinkId, Result<(), JoinError>)>>,
206    running_coordinator_worker: HashMap<SinkId, CoordinatorWorkerHandle>,
207}
208
209enum ManagerEvent {
210    NewRequest(ManagerRequest),
211    CoordinatorWorkerFinished {
212        sink_id: SinkId,
213        join_result: Result<(), JoinError>,
214    },
215}
216
217trait SpawnCoordinatorFn = FnMut(SinkParam, UnboundedReceiver<SinkWriterCoordinationHandle>) -> JoinHandle<()>
218    + Send
219    + 'static;
220
221impl ManagerWorker {
222    fn new(request_rx: mpsc::Receiver<ManagerRequest>, shutdown_rx: Receiver<()>) -> Self {
223        ManagerWorker {
224            request_rx,
225            shutdown_rx,
226            running_coordinator_worker_join_handles: Default::default(),
227            running_coordinator_worker: Default::default(),
228        }
229    }
230
231    async fn execute(mut self, mut spawn_coordinator_worker: impl SpawnCoordinatorFn) {
232        while let Some(event) = self.next_event().await {
233            match event {
234                ManagerEvent::NewRequest(request) => match request {
235                    ManagerRequest::NewSinkWriter(request) => {
236                        self.handle_new_sink_writer(request, &mut spawn_coordinator_worker)
237                    }
238                    ManagerRequest::StopCoordinator {
239                        finish_notifier,
240                        sink_ids,
241                    } => {
242                        if let Some(sink_ids) = sink_ids {
243                            let mut rxs = Vec::with_capacity(sink_ids.len());
244                            for sink_id in sink_ids {
245                                if let Some(worker_handle) =
246                                    self.running_coordinator_worker.get_mut(&sink_id)
247                                {
248                                    let (tx, rx) = oneshot::channel();
249                                    rxs.push(rx);
250                                    worker_handle.finish_notifiers.push(tx);
251                                    if let Some(sender) = worker_handle.request_sender.take() {
252                                        // drop the sender as a signal to notify the coordinator worker
253                                        // to stop
254                                        drop(sender);
255                                    }
256                                } else {
257                                    debug!(
258                                        "sink coordinator of {} is not running, skip it",
259                                        sink_id
260                                    );
261                                }
262                            }
263                            tokio::spawn(async move {
264                                let notify_res = join_all(rxs).await;
265                                for res in notify_res {
266                                    if let Err(e) = res {
267                                        error!(
268                                            "fail to wait for resetting sink manager worker: {}",
269                                            e.as_report()
270                                        );
271                                    }
272                                }
273                                send_with_err_check!(finish_notifier, ());
274                            });
275                        } else {
276                            self.clean_up().await;
277                            send_with_err_check!(finish_notifier, ());
278                        }
279                    }
280                },
281                ManagerEvent::CoordinatorWorkerFinished {
282                    sink_id,
283                    join_result,
284                } => self.handle_coordinator_finished(sink_id, join_result),
285            }
286        }
287        self.clean_up().await;
288        info!("sink manager worker exited");
289    }
290
291    async fn next_event(&mut self) -> Option<ManagerEvent> {
292        match select(
293            select(
294                pin!(self.request_rx.recv()),
295                pin!(pending_on_none(
296                    self.running_coordinator_worker_join_handles.next()
297                )),
298            ),
299            &mut self.shutdown_rx,
300        )
301        .await
302        {
303            Either::Left((either, _)) => match either {
304                Either::Left((Some(request), _)) => Some(ManagerEvent::NewRequest(request)),
305                Either::Left((None, _)) => None,
306                Either::Right(((sink_id, join_result), _)) => {
307                    Some(ManagerEvent::CoordinatorWorkerFinished {
308                        sink_id,
309                        join_result,
310                    })
311                }
312            },
313            Either::Right(_) => None,
314        }
315    }
316
317    async fn clean_up(&mut self) {
318        info!("sink manager worker start cleaning up");
319        for worker_handle in self.running_coordinator_worker.values_mut() {
320            if let Some(sender) = worker_handle.request_sender.take() {
321                // drop the sender to notify the coordinator worker to stop
322                drop(sender);
323            }
324        }
325        while let Some((sink_id, join_result)) =
326            self.running_coordinator_worker_join_handles.next().await
327        {
328            self.handle_coordinator_finished(sink_id, join_result);
329        }
330        info!("sink manager worker finished cleaning up");
331    }
332
333    fn handle_coordinator_finished(&mut self, sink_id: SinkId, join_result: Result<(), JoinError>) {
334        let worker_handle = self
335            .running_coordinator_worker
336            .remove(&sink_id)
337            .expect("finished coordinator should have an associated worker handle");
338        for finish_notifier in worker_handle.finish_notifiers {
339            send_with_err_check!(finish_notifier, ());
340        }
341        match join_result {
342            Ok(()) => {
343                info!(
344                    id = %sink_id,
345                    "sink coordinator has gracefully finished",
346                );
347            }
348            Err(err) => {
349                error!(
350                    id = %sink_id,
351                    error = %err.as_report(),
352                    "sink coordinator finished with error",
353                );
354            }
355        }
356    }
357
358    fn handle_new_sink_writer(
359        &mut self,
360        new_writer: SinkWriterCoordinationHandle,
361        spawn_coordinator_worker: &mut impl SpawnCoordinatorFn,
362    ) {
363        let param = new_writer.param();
364        let sink_id = param.sink_id;
365
366        let handle = self
367            .running_coordinator_worker
368            .entry(param.sink_id)
369            .or_insert_with(|| {
370                // Launch the coordinator worker task if it is the first
371                let (request_tx, request_rx) = unbounded_channel();
372                let join_handle = spawn_coordinator_worker(param.clone(), request_rx);
373                self.running_coordinator_worker_join_handles.push(
374                    join_handle
375                        .map(move |join_result| (sink_id, join_result))
376                        .boxed(),
377                );
378                CoordinatorWorkerHandle {
379                    request_sender: Some(request_tx),
380                    finish_notifiers: Vec::new(),
381                }
382            });
383
384        if let Some(sender) = handle.request_sender.as_mut() {
385            send_with_err_check!(sender, new_writer);
386        } else {
387            warn!(
388                "handle a new request while the sink coordinator is being stopped: {:?}",
389                param
390            );
391            new_writer.abort(Status::internal("the sink is being stopped"));
392        }
393    }
394}
395
396#[cfg(test)]
397mod tests {
398    use std::future::{Future, poll_fn};
399    use std::pin::pin;
400    use std::sync::Arc;
401    use std::sync::atomic::AtomicI32;
402    use std::task::Poll;
403
404    use anyhow::anyhow;
405    use async_trait::async_trait;
406    use futures::future::{join, try_join};
407    use futures::{FutureExt, StreamExt, TryFutureExt};
408    use itertools::Itertools;
409    use rand::seq::SliceRandom;
410    use risingwave_common::bitmap::BitmapBuilder;
411    use risingwave_common::catalog::Field;
412    use risingwave_common::hash::VirtualNode;
413    use risingwave_connector::sink::catalog::{SinkId, SinkType};
414    use risingwave_connector::sink::{
415        SinglePhaseCommitCoordinator, SinkCommitCoordinator, SinkError, SinkParam,
416        TwoPhaseCommitCoordinator,
417    };
418    use risingwave_pb::connector_service::SinkMetadata;
419    use risingwave_pb::connector_service::sink_metadata::{Metadata, SerializedMetadata};
420    use risingwave_rpc_client::CoordinatorStreamHandle;
421    use sea_orm::{ConnectionTrait, Database, DatabaseConnection};
422    use tokio::sync::mpsc::unbounded_channel;
423    use tokio_stream::wrappers::ReceiverStream;
424
425    use crate::manager::sink_coordination::SinkCoordinatorManager;
426    use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
427    use crate::manager::sink_coordination::manager::SinkCommittedEpochSubscriber;
428
429    struct MockSinglePhaseCoordinator<
430        C,
431        F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError>,
432    > {
433        context: C,
434        f: F,
435    }
436
437    impl<
438        C: Send + 'static,
439        F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError> + Send + 'static,
440    > MockSinglePhaseCoordinator<C, F>
441    {
442        fn new_coordinator(context: C, f: F) -> SinkCommitCoordinator {
443            SinkCommitCoordinator::SinglePhase(Box::new(MockSinglePhaseCoordinator { context, f }))
444        }
445    }
446
447    #[async_trait]
448    impl<C: Send, F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError> + Send>
449        SinglePhaseCommitCoordinator for MockSinglePhaseCoordinator<C, F>
450    {
451        async fn init(&mut self) -> risingwave_connector::sink::Result<()> {
452            Ok(())
453        }
454
455        async fn commit(
456            &mut self,
457            epoch: u64,
458            metadata: Vec<SinkMetadata>,
459            _add_columns: Option<Vec<Field>>,
460        ) -> risingwave_connector::sink::Result<()> {
461            (self.f)(epoch, metadata, &mut self.context)
462        }
463    }
464
465    #[tokio::test]
466    async fn test_basic() {
467        let param = SinkParam {
468            sink_id: SinkId::from(1),
469            sink_name: "test".into(),
470            properties: Default::default(),
471            columns: vec![],
472            downstream_pk: None,
473            sink_type: SinkType::AppendOnly,
474            format_desc: None,
475            db_name: "test".into(),
476            sink_from_name: "test".into(),
477        };
478
479        let epoch0 = 232;
480        let epoch1 = 233;
481        let epoch2 = 234;
482
483        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
484        all_vnode.shuffle(&mut rand::rng());
485        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
486        let build_bitmap = |indexes: &[usize]| {
487            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
488            for i in indexes {
489                builder.set(*i, true);
490            }
491            builder.finish()
492        };
493        let vnode1 = build_bitmap(first);
494        let vnode2 = build_bitmap(second);
495
496        let metadata = [
497            [vec![1u8, 2u8], vec![3u8, 4u8]],
498            [vec![5u8, 6u8], vec![7u8, 8u8]],
499        ];
500        let sender = Arc::new(tokio::sync::Mutex::new(None));
501        let mock_subscriber: SinkCommittedEpochSubscriber = {
502            let captured_sender = sender.clone();
503            Arc::new(move |_sink_id: SinkId| {
504                let (sender, receiver) = unbounded_channel();
505                let captured_sender = captured_sender.clone();
506                async move {
507                    let mut guard = captured_sender.lock().await;
508                    *guard = Some(sender);
509                    Ok((1, receiver))
510                }
511                .boxed()
512            })
513        };
514
515        let (manager, (_join_handle, _stop_tx)) =
516            SinkCoordinatorManager::start_worker_with_spawn_worker({
517                let expected_param = param.clone();
518                let metadata = metadata.clone();
519                move |param, new_writer_rx| {
520                    let metadata = metadata.clone();
521                    let expected_param = expected_param.clone();
522                    tokio::spawn({
523                        let subscriber = mock_subscriber.clone();
524                        async move {
525                            // validate the start request
526                            assert_eq!(param, expected_param);
527                            CoordinatorWorker::execute_coordinator(
528                                DatabaseConnection::Disconnected,
529                                param.clone(),
530                                new_writer_rx,
531                                MockSinglePhaseCoordinator::new_coordinator(
532                                    0,
533                                    move |epoch, metadata_list, count: &mut usize| {
534                                        *count += 1;
535                                        let mut metadata_list =
536                                            metadata_list
537                                                .into_iter()
538                                                .map(|metadata| match metadata {
539                                                    SinkMetadata {
540                                                        metadata:
541                                                            Some(Metadata::Serialized(
542                                                                SerializedMetadata { metadata },
543                                                            )),
544                                                    } => metadata,
545                                                    _ => unreachable!(),
546                                                })
547                                                .collect_vec();
548                                        metadata_list.sort();
549                                        match *count {
550                                            1 => {
551                                                assert_eq!(epoch, epoch1);
552                                                assert_eq!(2, metadata_list.len());
553                                                assert_eq!(metadata[0][0], metadata_list[0]);
554                                                assert_eq!(metadata[0][1], metadata_list[1]);
555                                            }
556                                            2 => {
557                                                assert_eq!(epoch, epoch2);
558                                                assert_eq!(2, metadata_list.len());
559                                                assert_eq!(metadata[1][0], metadata_list[0]);
560                                                assert_eq!(metadata[1][1], metadata_list[1]);
561                                            }
562                                            _ => unreachable!(),
563                                        }
564                                        Ok(())
565                                    },
566                                ),
567                                subscriber.clone(),
568                            )
569                            .await;
570                        }
571                    })
572                }
573            });
574
575        let build_client = |vnode| async {
576            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
577                Ok(tonic::Response::new(
578                    manager
579                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
580                        .await
581                        .unwrap()
582                        .boxed(),
583                ))
584            })
585            .await
586            .unwrap()
587            .0
588        };
589
590        let (mut client1, mut client2) =
591            join(build_client(vnode1), pin!(build_client(vnode2))).await;
592
593        let (aligned_epoch1, aligned_epoch2) = try_join(
594            client1.align_initial_epoch(epoch0),
595            client2.align_initial_epoch(epoch1),
596        )
597        .await
598        .unwrap();
599        assert_eq!(aligned_epoch1, epoch1);
600        assert_eq!(aligned_epoch2, epoch1);
601
602        {
603            // commit epoch1
604            let mut commit_future = pin!(
605                client2
606                    .commit(
607                        epoch1,
608                        SinkMetadata {
609                            metadata: Some(Metadata::Serialized(SerializedMetadata {
610                                metadata: metadata[0][1].clone(),
611                            })),
612                        },
613                        None,
614                    )
615                    .map(|result| result.unwrap())
616            );
617            assert!(
618                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
619                    .await
620                    .is_pending()
621            );
622            join(
623                commit_future,
624                client1
625                    .commit(
626                        epoch1,
627                        SinkMetadata {
628                            metadata: Some(Metadata::Serialized(SerializedMetadata {
629                                metadata: metadata[0][0].clone(),
630                            })),
631                        },
632                        None,
633                    )
634                    .map(|result| result.unwrap()),
635            )
636            .await;
637        }
638
639        // commit epoch2
640        let mut commit_future = pin!(
641            client1
642                .commit(
643                    epoch2,
644                    SinkMetadata {
645                        metadata: Some(Metadata::Serialized(SerializedMetadata {
646                            metadata: metadata[1][0].clone(),
647                        })),
648                    },
649                    None,
650                )
651                .map(|result| result.unwrap())
652        );
653        assert!(
654            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
655                .await
656                .is_pending()
657        );
658        join(
659            commit_future,
660            client2
661                .commit(
662                    epoch2,
663                    SinkMetadata {
664                        metadata: Some(Metadata::Serialized(SerializedMetadata {
665                            metadata: metadata[1][1].clone(),
666                        })),
667                    },
668                    None,
669                )
670                .map(|result| result.unwrap()),
671        )
672        .await;
673    }
674
675    #[tokio::test]
676    async fn test_single_writer() {
677        let param = SinkParam {
678            sink_id: SinkId::from(1),
679            sink_name: "test".into(),
680            properties: Default::default(),
681            columns: vec![],
682            downstream_pk: None,
683            sink_type: SinkType::AppendOnly,
684            format_desc: None,
685            db_name: "test".into(),
686            sink_from_name: "test".into(),
687        };
688
689        let epoch1 = 233;
690        let epoch2 = 234;
691
692        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
693        let build_bitmap = |indexes: &[usize]| {
694            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
695            for i in indexes {
696                builder.set(*i, true);
697            }
698            builder.finish()
699        };
700        let vnode = build_bitmap(&all_vnode);
701
702        let metadata = [vec![1u8, 2u8], vec![3u8, 4u8]];
703        let sender = Arc::new(tokio::sync::Mutex::new(None));
704        let mock_subscriber: SinkCommittedEpochSubscriber = {
705            let captured_sender = sender.clone();
706            Arc::new(move |_sink_id: SinkId| {
707                let (sender, receiver) = unbounded_channel();
708                let captured_sender = captured_sender.clone();
709                async move {
710                    let mut guard = captured_sender.lock().await;
711                    *guard = Some(sender);
712                    Ok((1, receiver))
713                }
714                .boxed()
715            })
716        };
717        let (manager, (_join_handle, _stop_tx)) =
718            SinkCoordinatorManager::start_worker_with_spawn_worker({
719                let expected_param = param.clone();
720                let metadata = metadata.clone();
721                move |param, new_writer_rx| {
722                    let metadata = metadata.clone();
723                    let expected_param = expected_param.clone();
724                    tokio::spawn({
725                        let subscriber = mock_subscriber.clone();
726                        async move {
727                            // validate the start request
728                            assert_eq!(param, expected_param);
729                            CoordinatorWorker::execute_coordinator(
730                                DatabaseConnection::Disconnected,
731                                param.clone(),
732                                new_writer_rx,
733                                MockSinglePhaseCoordinator::new_coordinator(
734                                    0,
735                                    move |epoch, metadata_list, count: &mut usize| {
736                                        *count += 1;
737                                        let mut metadata_list =
738                                            metadata_list
739                                                .into_iter()
740                                                .map(|metadata| match metadata {
741                                                    SinkMetadata {
742                                                        metadata:
743                                                            Some(Metadata::Serialized(
744                                                                SerializedMetadata { metadata },
745                                                            )),
746                                                    } => metadata,
747                                                    _ => unreachable!(),
748                                                })
749                                                .collect_vec();
750                                        metadata_list.sort();
751                                        match *count {
752                                            1 => {
753                                                assert_eq!(epoch, epoch1);
754                                                assert_eq!(1, metadata_list.len());
755                                                assert_eq!(metadata[0], metadata_list[0]);
756                                            }
757                                            2 => {
758                                                assert_eq!(epoch, epoch2);
759                                                assert_eq!(1, metadata_list.len());
760                                                assert_eq!(metadata[1], metadata_list[0]);
761                                            }
762                                            _ => unreachable!(),
763                                        }
764                                        Ok(())
765                                    },
766                                ),
767                                subscriber.clone(),
768                            )
769                            .await;
770                        }
771                    })
772                }
773            });
774
775        let build_client = |vnode| async {
776            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
777                Ok(tonic::Response::new(
778                    manager
779                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
780                        .await
781                        .unwrap()
782                        .boxed(),
783                ))
784            })
785            .await
786            .unwrap()
787            .0
788        };
789
790        let mut client = build_client(vnode).await;
791
792        let aligned_epoch = client.align_initial_epoch(epoch1).await.unwrap();
793        assert_eq!(aligned_epoch, epoch1);
794
795        client
796            .commit(
797                epoch1,
798                SinkMetadata {
799                    metadata: Some(Metadata::Serialized(SerializedMetadata {
800                        metadata: metadata[0].clone(),
801                    })),
802                },
803                None,
804            )
805            .await
806            .unwrap();
807
808        client
809            .commit(
810                epoch2,
811                SinkMetadata {
812                    metadata: Some(Metadata::Serialized(SerializedMetadata {
813                        metadata: metadata[1].clone(),
814                    })),
815                },
816                None,
817            )
818            .await
819            .unwrap();
820    }
821
822    #[tokio::test]
    // Partial-commit scenario: two writers each own one part of the vnodes, but
    // only `client1` commits the epoch. Its commit must still be pending after
    // one poll while `client2` is connected, and must resolve to an error once
    // `client2` is dropped without committing. The callback handed to the mock
    // coordinator is `unreachable!()`, i.e. it is expected never to be invoked.
823    async fn test_partial_commit() {
824        let param = SinkParam {
825            sink_id: SinkId::from(1),
826            sink_name: "test".into(),
827            properties: Default::default(),
828            columns: vec![],
829            downstream_pk: None,
830            sink_type: SinkType::AppendOnly,
831            format_desc: None,
832            db_name: "test".into(),
833            sink_from_name: "test".into(),
834        };
835
836        let epoch = 233;
837
        // Shuffle all vnode indexes and split the list at its midpoint, so each
        // writer gets a random part that is disjoint from the other one.
838        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
839        all_vnode.shuffle(&mut rand::rng());
840        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        // Builds a `COUNT_FOR_TEST`-bit bitmap with exactly the given indexes set.
841        let build_bitmap = |indexes: &[usize]| {
842            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
843            for i in indexes {
844                builder.set(*i, true);
845            }
846            builder.finish()
847        };
848        let vnode1 = build_bitmap(first);
849        let vnode2 = build_bitmap(second);
850
        // Shared slot that receives the sender half of the channel created by
        // the mock subscriber below. Nothing in this test reads it back; storing
        // the sender here keeps it alive for as long as `sender` is in scope.
851        let sender = Arc::new(tokio::sync::Mutex::new(None));
        // Mock subscriber: each call creates a fresh unbounded channel, stores
        // its sender in the shared slot (replacing any previous one) and
        // resolves to `Ok((1, receiver))`.
        // NOTE(review): the `1` is presumably the initially committed epoch —
        // confirm against `SinkCommittedEpochSubscriber`.
852        let mock_subscriber: SinkCommittedEpochSubscriber = {
853            let captured_sender = sender.clone();
854            Arc::new(move |_sink_id: SinkId| {
855                let (sender, receiver) = unbounded_channel();
856                let captured_sender = captured_sender.clone();
857                async move {
858                    let mut guard = captured_sender.lock().await;
859                    *guard = Some(sender);
860                    Ok((1, receiver))
861                }
862                .boxed()
863            })
864        };
        // Start the manager with a custom worker spawner. The spawned task
        // checks the received `SinkParam` against the expected one and then runs
        // `CoordinatorWorker::execute_coordinator` with a disconnected database
        // handle and the mock coordinator.
865        let (manager, (_join_handle, _stop_tx)) =
866            SinkCoordinatorManager::start_worker_with_spawn_worker({
867                let expected_param = param.clone();
868                move |param, new_writer_rx| {
869                    let expected_param = expected_param.clone();
870                    tokio::spawn({
871                        let subscriber = mock_subscriber.clone();
872                        async move {
873                            // validate the start request
874                            assert_eq!(param, expected_param);
875                            CoordinatorWorker::execute_coordinator(
876                                DatabaseConnection::Disconnected,
877                                param,
878                                new_writer_rx,
879                                MockSinglePhaseCoordinator::new_coordinator(
880                                    (),
                                    // No commit may reach the coordinator in this test.
881                                    |_, _, _| unreachable!(),
882                                ),
883                                subscriber.clone(),
884                            )
885                            .await;
886                        }
887                    })
888                }
889            });
890
        // Opens a coordinator stream for `vnode`: the request channel is fed
        // directly into `manager.handle_new_request`, and only the first element
        // (`.0`) of the successful result is kept as the client handle.
891        let build_client = |vnode| async {
892            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
893                Ok(tonic::Response::new(
894                    manager
895                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
896                        .await
897                        .unwrap()
898                        .boxed(),
899                ))
900            })
901            .await
902            .unwrap()
903            .0
904        };
905
        // Both clients are built concurrently.
906        let (mut client1, client2) = join(build_client(vnode1), build_client(vnode2)).await;
907
908        // commit epoch
909        let mut commit_future = pin!(client1.commit(
910            epoch,
911            SinkMetadata {
912                metadata: Some(Metadata::Serialized(SerializedMetadata {
913                    metadata: vec![],
914                })),
915            },
916            None,
917        ));
        // Poll the commit exactly once: wrapping the poll result in
        // `Poll::Ready` makes `poll_fn` resolve immediately with the inner
        // `Poll`. The test expects `Pending` here since `client2` has not
        // committed.
918        assert!(
919            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
920                .await
921                .is_pending()
922        );
        // Drop the second writer without committing; the pending commit of
        // `client1` is then expected to resolve to an error.
923        drop(client2);
924        assert!(commit_future.await.is_err());
925    }
926
927    #[tokio::test]
    // Failing-commit scenario: two writers each own one part of the vnodes and
    // both commit the same epoch, but the callback handed to the mock
    // coordinator always returns `SinkError::Coordinator`. Both commit calls are
    // expected to return an error.
928    async fn test_fail_commit() {
929        let param = SinkParam {
930            sink_id: SinkId::from(1),
931            sink_name: "test".into(),
932            properties: Default::default(),
933            columns: vec![],
934            downstream_pk: None,
935            sink_type: SinkType::AppendOnly,
936            format_desc: None,
937            db_name: "test".into(),
938            sink_from_name: "test".into(),
939        };
940
941        let epoch = 233;
942
        // Shuffle all vnode indexes and split the list at its midpoint, so each
        // writer gets a random part that is disjoint from the other one.
943        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
944        all_vnode.shuffle(&mut rand::rng());
945        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        // Builds a `COUNT_FOR_TEST`-bit bitmap with exactly the given indexes set.
946        let build_bitmap = |indexes: &[usize]| {
947            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
948            for i in indexes {
949                builder.set(*i, true);
950            }
951            builder.finish()
952        };
953        let vnode1 = build_bitmap(first);
954        let vnode2 = build_bitmap(second);
        // Shared slot that receives the sender half of the channel created by
        // the mock subscriber below. Nothing in this test reads it back; storing
        // the sender here keeps it alive for as long as `sender` is in scope.
955        let sender = Arc::new(tokio::sync::Mutex::new(None));
        // Mock subscriber: each call creates a fresh unbounded channel, stores
        // its sender in the shared slot (replacing any previous one) and
        // resolves to `Ok((1, receiver))`.
        // NOTE(review): the `1` is presumably the initially committed epoch —
        // confirm against `SinkCommittedEpochSubscriber`.
956        let mock_subscriber: SinkCommittedEpochSubscriber = {
957            let captured_sender = sender.clone();
958            Arc::new(move |_sink_id: SinkId| {
959                let (sender, receiver) = unbounded_channel();
960                let captured_sender = captured_sender.clone();
961                async move {
962                    let mut guard = captured_sender.lock().await;
963                    *guard = Some(sender);
964                    Ok((1, receiver))
965                }
966                .boxed()
967            })
968        };
        // Start the manager with a custom worker spawner. The spawned task
        // checks the received `SinkParam` against the expected one and then runs
        // `CoordinatorWorker::execute_coordinator` with a disconnected database
        // handle and a mock coordinator whose callback always fails.
969        let (manager, (_join_handle, _stop_tx)) =
970            SinkCoordinatorManager::start_worker_with_spawn_worker({
971                let expected_param = param.clone();
972                move |param, new_writer_rx| {
973                    let expected_param = expected_param.clone();
974                    tokio::spawn({
975                        let subscriber = mock_subscriber.clone();
                        // The extra `{ ... }` below is a plain block expression
                        // around the async block; it has no effect of its own.
976                        {
977                            async move {
978                                // validate the start request
979                                assert_eq!(param, expected_param);
980                                CoordinatorWorker::execute_coordinator(
981                                    DatabaseConnection::Disconnected,
982                                    param,
983                                    new_writer_rx,
                                    // Every invocation of the callback reports a coordinator error.
984                                    MockSinglePhaseCoordinator::new_coordinator((), |_, _, _| {
985                                        Err(SinkError::Coordinator(anyhow!("failed to commit")))
986                                    }),
987                                    subscriber.clone(),
988                                )
989                                .await;
990                            }
991                        }
992                    })
993                }
994            });
995
        // Opens a coordinator stream for `vnode`: the request channel is fed
        // directly into `manager.handle_new_request`, and only the first element
        // (`.0`) of the successful result is kept as the client handle.
996        let build_client = |vnode| async {
997            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
998                Ok(tonic::Response::new(
999                    manager
1000                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1001                        .await
1002                        .unwrap()
1003                        .boxed(),
1004                ))
1005            })
1006            .await
1007            .unwrap()
1008            .0
1009        };
1010
        // Both clients are built concurrently.
1011        let (mut client1, mut client2) = join(build_client(vnode1), build_client(vnode2)).await;
1012
1013        // commit epoch
1014        let mut commit_future = pin!(client1.commit(
1015            epoch,
1016            SinkMetadata {
1017                metadata: Some(Metadata::Serialized(SerializedMetadata {
1018                    metadata: vec![],
1019                })),
1020            },
1021            None,
1022        ));
        // Poll the first commit exactly once: wrapping the poll result in
        // `Poll::Ready` makes `poll_fn` resolve immediately with the inner
        // `Poll`. The test expects `Pending` here since `client2` has not
        // committed yet.
1023        assert!(
1024            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
1025                .await
1026                .is_pending()
1027        );
        // Now let the second writer commit the same epoch and drive both
        // commits to completion together.
1028        let (result1, result2) = join(
1029            commit_future,
1030            client2.commit(
1031                epoch,
1032                SinkMetadata {
1033                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1034                        metadata: vec![],
1035                    })),
1036                },
1037                None,
1038            ),
1039        )
1040        .await;
        // Both writers must observe the failure.
1041        assert!(result1.is_err());
1042        assert!(result2.is_err());
1043    }
1044
1045    #[tokio::test]
    // Vnode-bitmap update scenario across four epochs:
    //   1. epoch1, epoch2: two writers (`client1`, `client2`) commit.
    //   2. scale out: `client3` joins while `client1`/`client2` update their
    //      vnode bitmaps; epoch3 is committed by all three writers.
    //   3. scale in: `client1` stops while `client2`/`client3` update their
    //      vnode bitmaps; epoch4 is committed by the two remaining writers.
    // The callback handed to the mock coordinator counts its invocations and
    // checks the epoch and the sorted metadata of each one.
1046    async fn test_update_vnode_bitmap() {
1047        let param = SinkParam {
1048            sink_id: SinkId::from(1),
1049            sink_name: "test".into(),
1050            properties: Default::default(),
1051            columns: vec![],
1052            downstream_pk: None,
1053            sink_type: SinkType::AppendOnly,
1054            format_desc: None,
1055            db_name: "test".into(),
1056            sink_from_name: "test".into(),
1057        };
1058
1059        let epoch1 = 233;
1060        let epoch2 = 234;
1061        let epoch3 = 235;
1062        let epoch4 = 236;
1063
        // Initial layout: shuffle all vnode indexes and split the list at its
        // midpoint, giving the two initial writers disjoint random parts.
1064        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1065        all_vnode.shuffle(&mut rand::rng());
1066        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        // Builds a `COUNT_FOR_TEST`-bit bitmap with exactly the given indexes set.
1067        let build_bitmap = |indexes: &[usize]| {
1068            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1069            for i in indexes {
1070                builder.set(*i, true);
1071            }
1072            builder.finish()
1073        };
1074        let vnode1 = build_bitmap(first);
1075        let vnode2 = build_bitmap(second);
1076
        // `metadata[e][w]`: bytes committed in the e-th two-writer epoch
        // (0 = epoch1, 1 = epoch2) by writer w (0 = client1, 1 = client2).
        // Each row is already in sorted order, which is what the coordinator
        // callback compares against after sorting what it received.
1077        let metadata = [
1078            [vec![1u8, 2u8], vec![3u8, 4u8]],
1079            [vec![5u8, 6u8], vec![7u8, 8u8]],
1080        ];
1081
        // Expected (sorted) metadata of epoch3 (three writers) and epoch4 (two
        // writers after scale-in).
1082        let metadata_scale_out = [vec![9u8, 10u8], vec![11u8, 12u8], vec![13u8, 14u8]];
1083        let metadata_scale_in = [vec![13u8, 14u8], vec![15u8, 16u8]];
        // Shared slot that receives the sender half of the channel created by
        // the mock subscriber below. Nothing in this test reads it back; storing
        // the sender here keeps it alive for as long as `sender` is in scope.
1084        let sender = Arc::new(tokio::sync::Mutex::new(None));
        // Mock subscriber: each call creates a fresh unbounded channel, stores
        // its sender in the shared slot (replacing any previous one) and
        // resolves to `Ok((1, receiver))`.
        // NOTE(review): the `1` is presumably the initially committed epoch —
        // confirm against `SinkCommittedEpochSubscriber`.
1085        let mock_subscriber: SinkCommittedEpochSubscriber = {
1086            let captured_sender = sender.clone();
1087            Arc::new(move |_sink_id: SinkId| {
1088                let (sender, receiver) = unbounded_channel();
1089                let captured_sender = captured_sender.clone();
1090                async move {
1091                    let mut guard = captured_sender.lock().await;
1092                    *guard = Some(sender);
1093                    Ok((1, receiver))
1094                }
1095                .boxed()
1096            })
1097        };
        // Start the manager with a custom worker spawner. The expected values
        // are cloned once into the spawner closure and once more per spawn so
        // that the spawned task owns its own copies.
1098        let (manager, (_join_handle, _stop_tx)) =
1099            SinkCoordinatorManager::start_worker_with_spawn_worker({
1100                let expected_param = param.clone();
1101                let metadata = metadata.clone();
1102                let metadata_scale_out = metadata_scale_out.clone();
1103                let metadata_scale_in = metadata_scale_in.clone();
1104                move |param, new_writer_rx| {
1105                    let metadata = metadata.clone();
1106                    let metadata_scale_out = metadata_scale_out.clone();
1107                    let metadata_scale_in = metadata_scale_in.clone();
1108                    let expected_param = expected_param.clone();
1109                    tokio::spawn({
1110                        let subscriber = mock_subscriber.clone();
1111                        async move {
1112                            // validate the start request
1113                            assert_eq!(param, expected_param);
1114                            CoordinatorWorker::execute_coordinator(
1115                                DatabaseConnection::Disconnected,
1116                                param.clone(),
1117                                new_writer_rx,
                                // The mock's context value (`0`) is handed to the
                                // callback as `count` and used as a call counter.
1118                                MockSinglePhaseCoordinator::new_coordinator(
1119                                    0,
1120                                    move |epoch, metadata_list, count: &mut usize| {
1121                                        *count += 1;
                                        // Unwrap every `SinkMetadata` into its
                                        // serialized bytes and sort the result, so
                                        // the comparison below does not depend on
                                        // the order in which writers reported.
1122                                        let mut metadata_list =
1123                                            metadata_list
1124                                                .into_iter()
1125                                                .map(|metadata| match metadata {
1126                                                    SinkMetadata {
1127                                                        metadata:
1128                                                            Some(Metadata::Serialized(
1129                                                                SerializedMetadata { metadata },
1130                                                            )),
1131                                                    } => metadata,
1132                                                    _ => unreachable!(),
1133                                                })
1134                                                .collect_vec();
1135                                        metadata_list.sort();
                                        // The n-th call must carry the n-th expected
                                        // (epoch, metadata) pair; a fifth call hits
                                        // `unreachable!()`.
1136                                        let (expected_epoch, expected_metadata_list) = match *count
1137                                        {
1138                                            1 => (epoch1, metadata[0].as_slice()),
1139                                            2 => (epoch2, metadata[1].as_slice()),
1140                                            3 => (epoch3, metadata_scale_out.as_slice()),
1141                                            4 => (epoch4, metadata_scale_in.as_slice()),
1142                                            _ => unreachable!(),
1143                                        };
1144                                        assert_eq!(expected_epoch, epoch);
1145                                        assert_eq!(expected_metadata_list, &metadata_list);
1146                                        Ok(())
1147                                    },
1148                                ),
1149                                subscriber.clone(),
1150                            )
1151                            .await;
1152                        }
1153                    })
1154                }
1155            });
1156
        // Opens a coordinator stream for `vnode`, fed directly into
        // `manager.handle_new_request`. Unlike a variant that unwraps, this
        // closure returns the raw `Result`; callers below destructure the
        // success value as a `(client, initial epoch)` pair.
1157        let build_client = |vnode| async {
1158            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1159                Ok(tonic::Response::new(
1160                    manager
1161                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1162                        .await
1163                        .unwrap()
1164                        .boxed(),
1165                ))
1166            })
1167            .await
1168        };
1169
        // Connect the two initial writers concurrently; the second element of
        // each returned pair is ignored here.
1170        let ((mut client1, _), (mut client2, _)) =
1171            try_join(build_client(vnode1), pin!(build_client(vnode2)))
1172                .await
1173                .unwrap();
1174
        // Both writers align their initial epoch on `epoch1` and must both get
        // `epoch1` back.
1175        let (aligned_epoch1, aligned_epoch2) = try_join(
1176            client1.align_initial_epoch(epoch1),
1177            client2.align_initial_epoch(epoch1),
1178        )
1179        .await
1180        .unwrap();
1181        assert_eq!(aligned_epoch1, epoch1);
1182        assert_eq!(aligned_epoch2, epoch1);
1183
1184        {
1185            // commit epoch1
            // `client2` commits first; after a single manual poll its commit
            // must still be pending. It is then driven to completion together
            // with the commit of `client1`.
1186            let mut commit_future = pin!(
1187                client2
1188                    .commit(
1189                        epoch1,
1190                        SinkMetadata {
1191                            metadata: Some(Metadata::Serialized(SerializedMetadata {
1192                                metadata: metadata[0][1].clone(),
1193                            })),
1194                        },
1195                        None,
1196                    )
1197                    .map(|result| result.unwrap())
1198            );
            // `Poll::Ready(inner_poll)` makes `poll_fn` resolve after exactly
            // one poll of the commit future.
1199            assert!(
1200                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
1201                    .await
1202                    .is_pending()
1203            );
1204            join(
1205                commit_future,
1206                client1
1207                    .commit(
1208                        epoch1,
1209                        SinkMetadata {
1210                            metadata: Some(Metadata::Serialized(SerializedMetadata {
1211                                metadata: metadata[0][0].clone(),
1212                            })),
1213                        },
1214                        None,
1215                    )
1216                    .map(|result| result.unwrap()),
1217            )
1218            .await;
1219        }
1220
        // Scale-out layout: re-partition the same shuffled vnode list into
        // three parts. The first two have `COUNT_FOR_TEST / 3` entries each and
        // the third takes whatever remains.
1221        let (vnode1, vnode2, vnode3) = {
1222            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
1223            let (second, third) = second.split_at(VirtualNode::COUNT_FOR_TEST / 3);
1224            (
1225                build_bitmap(first),
1226                build_bitmap(second),
1227                build_bitmap(third),
1228            )
1229        };
1230
        // Start connecting the third writer. After a single poll the connection
        // must still be pending; it is driven to completion further below,
        // together with the vnode bitmap updates of `client1` and `client2`.
1231        let mut build_client3_future = pin!(build_client(vnode3));
1232        assert!(
1233            poll_fn(|cx| Poll::Ready(build_client3_future.as_mut().poll(cx)))
1234                .await
1235                .is_pending()
1236        );
        // Declared outside the following block because `client3` is used again
        // in the scale-in phase after the block ends.
1237        let mut client3;
1238        {
1239            {
1240                // commit epoch2
                // Still with the two original writers: `client1` commits first
                // and must be pending after one poll, then both commits are
                // driven to completion together.
1241                let mut commit_future = pin!(
1242                    client1
1243                        .commit(
1244                            epoch2,
1245                            SinkMetadata {
1246                                metadata: Some(Metadata::Serialized(SerializedMetadata {
1247                                    metadata: metadata[1][0].clone(),
1248                                })),
1249                            },
1250                            None,
1251                        )
1252                        .map_err(Into::into)
1253                );
1254                assert!(
1255                    poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
1256                        .await
1257                        .is_pending()
1258                );
1259                try_join(
1260                    commit_future,
1261                    client2.commit(
1262                        epoch2,
1263                        SinkMetadata {
1264                            metadata: Some(Metadata::Serialized(SerializedMetadata {
1265                                metadata: metadata[1][1].clone(),
1266                            })),
1267                        },
1268                        None,
1269                    ),
1270                )
1271                .await
1272                .unwrap();
1273            }
1274
            // Finish connecting `client3` concurrently with the vnode bitmap
            // updates of `client1` and `client2`. The initial epoch reported to
            // `client3` and the epoch returned by both updates must be `epoch2`,
            // the last epoch committed so far in this test.
1275            client3 = {
1276                let (
1277                    (client3, init_epoch),
1278                    (update_vnode_bitmap_epoch1, update_vnode_bitmap_epoch2),
1279                ) = try_join(
1280                    build_client3_future,
1281                    try_join(
1282                        client1.update_vnode_bitmap(&vnode1),
1283                        client2.update_vnode_bitmap(&vnode2),
1284                    )
1285                    .map_err(Into::into),
1286                )
1287                .await
1288                .unwrap();
1289                assert_eq!(init_epoch, Some(epoch2));
1290                assert_eq!(update_vnode_bitmap_epoch1, epoch2);
1291                assert_eq!(update_vnode_bitmap_epoch2, epoch2);
1292                client3
1293            };
            // Commit epoch3 with three writers. `client3` commits first and
            // must be pending after one poll.
1294            let mut commit_future3 = pin!(client3.commit(
1295                epoch3,
1296                SinkMetadata {
1297                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1298                        metadata: metadata_scale_out[2].clone(),
1299                    })),
1300                },
1301                None,
1302            ));
1303            assert!(
1304                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
1305                    .await
1306                    .is_pending()
1307            );
            // `client1` commits second and must be pending as well.
1308            let mut commit_future1 = pin!(client1.commit(
1309                epoch3,
1310                SinkMetadata {
1311                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1312                        metadata: metadata_scale_out[0].clone(),
1313                    })),
1314                },
1315                None,
1316            ));
1317            assert!(
1318                poll_fn(|cx| Poll::Ready(commit_future1.as_mut().poll(cx)))
1319                    .await
1320                    .is_pending()
1321            );
            // Two out of three commits are not enough: re-polling the commit of
            // `client3` must still yield `Pending`.
1322            assert!(
1323                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
1324                    .await
1325                    .is_pending()
1326            );
            // Once `client2` commits too, all three commits are expected to
            // complete successfully.
1327            try_join(
1328                client2.commit(
1329                    epoch3,
1330                    SinkMetadata {
1331                        metadata: Some(Metadata::Serialized(SerializedMetadata {
1332                            metadata: metadata_scale_out[1].clone(),
1333                        })),
1334                    },
1335                    None,
1336                ),
1337                try_join(commit_future1, commit_future3),
1338            )
1339            .await
1340            .unwrap();
1341        }
1342
        // Scale-in layout: `client2` takes the first `COUNT_FOR_TEST / 3`
        // entries of the shuffled list and `client3` takes all the rest.
1343        let (vnode2, vnode3) = {
1344            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
1345            (build_bitmap(first), build_bitmap(second))
1346        };
1347
1348        {
            // `client1` stops while the two remaining writers switch to the new
            // bitmaps. Both updates must return `epoch3`, the last epoch
            // committed so far in this test.
1349            let (_, (update_vnode_bitmap_epoch2, update_vnode_bitmap_epoch3)) = try_join(
1350                client1.stop(),
1351                try_join(
1352                    client2.update_vnode_bitmap(&vnode2),
1353                    client3.update_vnode_bitmap(&vnode3),
1354                ),
1355            )
1356            .await
1357            .unwrap();
1358            assert_eq!(update_vnode_bitmap_epoch2, epoch3);
1359            assert_eq!(update_vnode_bitmap_epoch3, epoch3);
1360        }
1361
1362        {
            // Commit epoch4 with the two remaining writers: `client2` commits
            // first and must be pending after one poll, then both commits are
            // driven to completion together.
1363            let mut commit_future = pin!(
1364                client2
1365                    .commit(
1366                        epoch4,
1367                        SinkMetadata {
1368                            metadata: Some(Metadata::Serialized(SerializedMetadata {
1369                                metadata: metadata_scale_in[0].clone(),
1370                            })),
1371                        },
1372                        None,
1373                    )
1374                    .map(|result| result.unwrap())
1375            );
1376            assert!(
1377                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
1378                    .await
1379                    .is_pending()
1380            );
1381            join(
1382                commit_future,
1383                client3
1384                    .commit(
1385                        epoch4,
1386                        SinkMetadata {
1387                            metadata: Some(Metadata::Serialized(SerializedMetadata {
1388                                metadata: metadata_scale_in[1].clone(),
1389                            })),
1390                        },
1391                        None,
1392                    )
1393                    .map(|result| result.unwrap()),
1394            )
1395            .await;
1396        }
1397    }
1398
1399    struct MockTwoPhaseCoordinator<
1400        P: FnMut(u64, Vec<SinkMetadata>) -> Result<Vec<u8>, SinkError>,
1401        C: FnMut(u64, Vec<u8>) -> Result<(), SinkError>,
1402    > {
        // Closure-backed test double for a two-phase commit coordinator: the
        // `TwoPhaseCommitCoordinator` impl in this file forwards `pre_commit`
        // and `commit` to the two closures stored here.

        /// Called with `(epoch, metadata)`; returns the serialized commit
        /// metadata or an error.
1403        pre_commit: P,
        /// Called with `(epoch, commit_metadata)`; returns `Ok(())` or an error.
1404        commit: C,
1405    }
1406
1407    impl<
1408        P: FnMut(u64, Vec<SinkMetadata>) -> Result<Vec<u8>, SinkError> + Send + 'static,
1409        C: FnMut(u64, Vec<u8>) -> Result<(), SinkError> + Send + 'static,
1410    > MockTwoPhaseCoordinator<P, C>
1411    {
        /// Stores the two closures in a boxed `MockTwoPhaseCoordinator` and
        /// returns it wrapped in the `SinkCommitCoordinator::TwoPhase` variant.
        ///
        /// The `Send + 'static` bounds on this impl apply to both closures.
1412        fn new_coordinator(pre_commit: P, commit: C) -> SinkCommitCoordinator {
1413            SinkCommitCoordinator::TwoPhase(Box::new(MockTwoPhaseCoordinator {
1414                pre_commit,
1415                commit,
1416            }))
1417        }
1418    }
1419
1420    #[async_trait]
    // Trait implementation of the mock: `pre_commit` and `commit` delegate to
    // the stored closures, `init` always succeeds and `abort` only logs.
1421    impl<
1422        P: FnMut(u64, Vec<SinkMetadata>) -> Result<Vec<u8>, SinkError> + Send + 'static,
1423        C: FnMut(u64, Vec<u8>) -> Result<(), SinkError> + Send + 'static,
1424    > TwoPhaseCommitCoordinator for MockTwoPhaseCoordinator<P, C>
1425    {
        // No initialization is needed for the mock; always returns `Ok(())`.
1426        async fn init(&mut self) -> risingwave_connector::sink::Result<()> {
1427            Ok(())
1428        }
1429
        // Forwards `epoch` and `metadata` to the `pre_commit` closure and
        // returns its result unchanged. `_add_columns` is ignored.
1430        async fn pre_commit(
1431            &mut self,
1432            epoch: u64,
1433            metadata: Vec<SinkMetadata>,
1434            _add_columns: Option<Vec<Field>>,
1435        ) -> risingwave_connector::sink::Result<Vec<u8>> {
1436            (self.pre_commit)(epoch, metadata)
1437        }
1438
        // Forwards `epoch` and `commit_metadata` to the `commit` closure and
        // returns its result unchanged.
1439        async fn commit(
1440            &mut self,
1441            epoch: u64,
1442            commit_metadata: Vec<u8>,
1443        ) -> risingwave_connector::sink::Result<()> {
1444            (self.commit)(epoch, commit_metadata)
1445        }
1446
        // Only emits a debug log line; both arguments are ignored.
1447        async fn abort(&mut self, _epoch: u64, _commit_metadata: Vec<u8>) {
1448            tracing::debug!("abort called");
1449        }
1450    }
1451
1452    async fn prepare_db_backend() -> DatabaseConnection {
        // Test helper: connects to an in-memory SQLite database, creates the
        // `pending_sink_state` table if it does not exist (primary key
        // `(sink_id, epoch)`) and returns the connection.
        // Panics via `unwrap` if connecting or executing the DDL fails.
1453        let db: DatabaseConnection = Database::connect("sqlite::memory:").await.unwrap();
        // NOTE(review): `i32`, `i64` and `STRING` are not standard SQL type
        // names; SQLite accepts arbitrary declared type names, so the statement
        // is expected to run as is — confirm the resulting column affinities
        // match what the code under test expects.
1454        let ddl = "
1455            CREATE TABLE IF NOT EXISTS pending_sink_state (
1456                sink_id i32 NOT NULL,
1457                epoch i64 NOT NULL,
1458                sink_state STRING NOT NULL,
1459                metadata BLOB NOT NULL,
1460                PRIMARY KEY (sink_id, epoch)
1461            )
1462        ";
1463        db.execute(sea_orm::Statement::from_string(
1464            db.get_database_backend(),
1465            ddl.to_owned(),
1466        ))
1467        .await
1468        .unwrap();
1469        db
1470    }
1471
1472    async fn list_rows(db: &DatabaseConnection) -> Vec<(i32, i64, String, Vec<u8>)> {
        // Test helper: selects every row of `pending_sink_state` and returns
        // them as `(sink_id, epoch, sink_state, metadata)` tuples. The query has
        // no `ORDER BY`, so callers should not rely on a particular row order.
        // Panics via `unwrap` if the query fails or a column cannot be decoded
        // into the corresponding tuple element type.
1473        let sql = "SELECT sink_id, epoch, sink_state, metadata FROM pending_sink_state";
1474        let rows = db
1475            .query_all(sea_orm::Statement::from_string(
1476                db.get_database_backend(),
1477                sql.to_owned(),
1478            ))
1479            .await
1480            .unwrap();
        // Decode each row by column name; the element types are inferred from
        // the function's return type.
1481        rows.into_iter()
1482            .map(|row| {
1483                (
1484                    row.try_get("", "sink_id").unwrap(),
1485                    row.try_get("", "epoch").unwrap(),
1486                    row.try_get("", "sink_state").unwrap(),
1487                    row.try_get("", "metadata").unwrap(),
1488                )
1489            })
1490            .collect()
1491    }
1492
1493    async fn set_epoch_aborted(db: &DatabaseConnection, sink_id: SinkId, epoch: u64) {
        // Test helper: sets `sink_state` to 'ABORTED' for the row of
        // `pending_sink_state` identified by `(sink_id, epoch)`.
        // Panics via `unwrap` if executing the statement fails.
        //
        // The statement text is assembled with `format!`; `epoch` is cast to
        // `i64` before interpolation.
        // NOTE(review): values are interpolated rather than bound as
        // parameters; this relies on `SinkId`'s `Display` output being a plain
        // number — confirm, and keep this pattern out of non-test code.
1494        let sql = format!(
1495            "UPDATE pending_sink_state SET sink_state = 'ABORTED' WHERE sink_id = {} AND epoch = {}",
1496            sink_id, epoch as i64
1497        );
1498        db.execute(sea_orm::Statement::from_string(
1499            db.get_database_backend(),
1500            sql,
1501        ))
1502        .await
1503        .unwrap();
1504    }
1505
1506    #[tokio::test]
1507    async fn test_pre_commit_failed() {
1508        let db = prepare_db_backend().await;
1509
1510        let param = SinkParam {
1511            sink_id: SinkId::from(1),
1512            sink_name: "test".into(),
1513            properties: Default::default(),
1514            columns: vec![],
1515            downstream_pk: None,
1516            sink_type: SinkType::AppendOnly,
1517            format_desc: None,
1518            db_name: "test".into(),
1519            sink_from_name: "test".into(),
1520        };
1521
1522        let epoch1 = 233;
1523
1524        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1525        let build_bitmap = |indexes: &[usize]| {
1526            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1527            for i in indexes {
1528                builder.set(*i, true);
1529            }
1530            builder.finish()
1531        };
1532        let vnode = build_bitmap(&all_vnode);
1533
1534        let metadata = vec![1u8, 2u8];
1535        let sender = Arc::new(tokio::sync::Mutex::new(None));
1536        let mock_subscriber: SinkCommittedEpochSubscriber = {
1537            let captured_sender = sender.clone();
1538            Arc::new(move |_sink_id: SinkId| {
1539                let (sender, receiver) = unbounded_channel();
1540                let captured_sender = captured_sender.clone();
1541                async move {
1542                    let mut guard = captured_sender.lock().await;
1543                    *guard = Some(sender);
1544                    Ok((epoch1, receiver))
1545                }
1546                .boxed()
1547            })
1548        };
1549
1550        let (manager, (_join_handle, _stop_tx)) =
1551            SinkCoordinatorManager::start_worker_with_spawn_worker({
1552                let expected_param = param.clone();
1553                let db = db.clone();
1554                move |param, new_writer_rx| {
1555                    let expected_param = expected_param.clone();
1556                    let db = db.clone();
1557                    tokio::spawn({
1558                        let subscriber = mock_subscriber.clone();
1559                        async move {
1560                            // validate the start request
1561                            assert_eq!(param, expected_param);
1562                            CoordinatorWorker::execute_coordinator(
1563                                db,
1564                                param.clone(),
1565                                new_writer_rx,
1566                                MockTwoPhaseCoordinator::new_coordinator(
1567                                    move |_epoch, _metadata_list| {
1568                                        Err(SinkError::Coordinator(anyhow!("failed to pre commit")))
1569                                    },
1570                                    move |_epoch, _commit_metadata| unreachable!(),
1571                                ),
1572                                subscriber.clone(),
1573                            )
1574                            .await;
1575                        }
1576                    })
1577                }
1578            });
1579
1580        let build_client = |vnode| async {
1581            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1582                Ok(tonic::Response::new(
1583                    manager
1584                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1585                        .await
1586                        .unwrap()
1587                        .boxed(),
1588                ))
1589            })
1590            .await
1591            .unwrap()
1592            .0
1593        };
1594
1595        let mut client = build_client(vnode).await;
1596
1597        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
1598        assert_eq!(aligned_epoch, 1);
1599
1600        let commit_result = client
1601            .commit(
1602                epoch1,
1603                SinkMetadata {
1604                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1605                        metadata: metadata.clone(),
1606                    })),
1607                },
1608                None,
1609            )
1610            .await;
1611        assert!(commit_result.is_err());
1612
1613        let rows = list_rows(&db).await;
1614        assert!(rows.is_empty());
1615    }
1616
1617    #[tokio::test]
1618    async fn test_waiting_on_checkpoint() {
1619        let db = prepare_db_backend().await;
1620
1621        let param = SinkParam {
1622            sink_id: SinkId::from(1),
1623            sink_name: "test".into(),
1624            properties: Default::default(),
1625            columns: vec![],
1626            downstream_pk: None,
1627            sink_type: SinkType::AppendOnly,
1628            format_desc: None,
1629            db_name: "test".into(),
1630            sink_from_name: "test".into(),
1631        };
1632
1633        let epoch0 = 232;
1634        let epoch1 = 233;
1635
1636        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1637        let build_bitmap = |indexes: &[usize]| {
1638            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1639            for i in indexes {
1640                builder.set(*i, true);
1641            }
1642            builder.finish()
1643        };
1644        let vnode = build_bitmap(&all_vnode);
1645
1646        let metadata = vec![1u8, 2u8];
1647
1648        let sender = Arc::new(tokio::sync::Mutex::new(None));
1649        let mock_subscriber: SinkCommittedEpochSubscriber = {
1650            let captured_sender = sender.clone();
1651            Arc::new(move |_sink_id: SinkId| {
1652                let (sender, receiver) = unbounded_channel();
1653                let captured_sender = captured_sender.clone();
1654                async move {
1655                    let mut guard = captured_sender.lock().await;
1656                    *guard = Some(sender);
1657                    Ok((epoch0, receiver))
1658                }
1659                .boxed()
1660            })
1661        };
1662
1663        let (manager, (_join_handle, _stop_tx)) =
1664            SinkCoordinatorManager::start_worker_with_spawn_worker({
1665                let expected_param = param.clone();
1666                let metadata = metadata.clone();
1667                let db = db.clone();
1668                move |param, new_writer_rx| {
1669                    let metadata = metadata.clone();
1670                    let expected_param = expected_param.clone();
1671                    let db = db.clone();
1672                    tokio::spawn({
1673                        let subscriber = mock_subscriber.clone();
1674                        async move {
1675                            // validate the start request
1676                            assert_eq!(param, expected_param);
1677                            CoordinatorWorker::execute_coordinator(
1678                                db,
1679                                param.clone(),
1680                                new_writer_rx,
1681                                MockTwoPhaseCoordinator::new_coordinator(
1682                                    move |_epoch, metadata_list| {
1683                                        let metadata =
1684                                            metadata_list.into_iter().exactly_one().unwrap();
1685                                        Ok(match metadata.metadata {
1686                                            Some(Metadata::Serialized(SerializedMetadata {
1687                                                metadata,
1688                                            })) => metadata,
1689                                            _ => unreachable!(),
1690                                        })
1691                                    },
1692                                    move |_epoch, commit_metadata| {
1693                                        assert_eq!(commit_metadata, metadata);
1694                                        Ok(())
1695                                    },
1696                                ),
1697                                subscriber.clone(),
1698                            )
1699                            .await;
1700                        }
1701                    })
1702                }
1703            });
1704
1705        let build_client = |vnode| async {
1706            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1707                Ok(tonic::Response::new(
1708                    manager
1709                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1710                        .await
1711                        .unwrap()
1712                        .boxed(),
1713                ))
1714            })
1715            .await
1716            .unwrap()
1717            .0
1718        };
1719
1720        let mut client = build_client(vnode).await;
1721
1722        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
1723        assert_eq!(aligned_epoch, 1);
1724
1725        client
1726            .commit(
1727                epoch1,
1728                SinkMetadata {
1729                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1730                        metadata: metadata.clone(),
1731                    })),
1732                },
1733                None,
1734            )
1735            .await
1736            .unwrap();
1737
1738        {
1739            let rows = list_rows(&db).await;
1740            assert_eq!(rows.len(), 1);
1741            assert_eq!(rows[0].1, epoch1 as i64);
1742            assert_eq!(rows[0].2, "PENDING");
1743
1744            let guard = sender.lock().await;
1745            let sender = guard.as_ref().unwrap().clone();
1746            sender.send(233).unwrap();
1747        }
1748
1749        // wait max 5 seconds for the commit to be processed
1750        for _ in 0..50 {
1751            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1752            let rows = list_rows(&db).await;
1753            if rows[0].2 == "COMMITTED" {
1754                break;
1755            }
1756        }
1757
1758        {
1759            let rows = list_rows(&db).await;
1760            assert_eq!(rows.len(), 1);
1761            assert_eq!(rows[0].1, epoch1 as i64);
1762            assert_eq!(rows[0].2, "COMMITTED");
1763        }
1764    }
1765
1766    #[tokio::test]
1767    async fn test_commit_retry_loop() {
1768        let db = prepare_db_backend().await;
1769
1770        let param = SinkParam {
1771            sink_id: SinkId::from(1),
1772            sink_name: "test".into(),
1773            properties: Default::default(),
1774            columns: vec![],
1775            downstream_pk: None,
1776            sink_type: SinkType::AppendOnly,
1777            format_desc: None,
1778            db_name: "test".into(),
1779            sink_from_name: "test".into(),
1780        };
1781
1782        let epoch1 = 233;
1783
1784        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1785        let build_bitmap = |indexes: &[usize]| {
1786            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1787            for i in indexes {
1788                builder.set(*i, true);
1789            }
1790            builder.finish()
1791        };
1792        let vnode = build_bitmap(&all_vnode);
1793
1794        let metadata = vec![1u8, 2u8];
1795        let sender = Arc::new(tokio::sync::Mutex::new(None));
1796        let mock_subscriber: SinkCommittedEpochSubscriber = {
1797            let captured_sender = sender.clone();
1798            Arc::new(move |_sink_id: SinkId| {
1799                let (sender, receiver) = unbounded_channel();
1800                let captured_sender = captured_sender.clone();
1801                async move {
1802                    let mut guard = captured_sender.lock().await;
1803                    *guard = Some(sender);
1804                    Ok((epoch1, receiver))
1805                }
1806                .boxed()
1807            })
1808        };
1809
1810        let commit_attempt = Arc::new(AtomicI32::new(0));
1811
1812        let (manager, (_join_handle, _stop_tx)) =
1813            SinkCoordinatorManager::start_worker_with_spawn_worker({
1814                let expected_param = param.clone();
1815                let metadata = metadata.clone();
1816                let db = db.clone();
1817                let commit_attempt = commit_attempt.clone();
1818                move |param, new_writer_rx| {
1819                    let metadata = metadata.clone();
1820                    let expected_param = expected_param.clone();
1821                    let db = db.clone();
1822                    let commit_attempt = commit_attempt.clone();
1823                    tokio::spawn({
1824                        let subscriber = mock_subscriber.clone();
1825                        async move {
1826                            // validate the start request
1827                            assert_eq!(param, expected_param);
1828                            CoordinatorWorker::execute_coordinator(
1829                                db,
1830                                param.clone(),
1831                                new_writer_rx,
1832                                MockTwoPhaseCoordinator::new_coordinator(
1833                                    move |_epoch, metadata_list| {
1834                                        let metadata =
1835                                            metadata_list.into_iter().exactly_one().unwrap();
1836                                        Ok(match metadata.metadata {
1837                                            Some(Metadata::Serialized(SerializedMetadata {
1838                                                metadata,
1839                                            })) => metadata,
1840                                            _ => unreachable!(),
1841                                        })
1842                                    },
1843                                    move |_epoch, commit_metadata| {
1844                                        assert_eq!(commit_metadata, metadata);
1845                                        if commit_attempt
1846                                            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
1847                                            < 2
1848                                        {
1849                                            Err(SinkError::Coordinator(anyhow!("failed to commit")))
1850                                        } else {
1851                                            Ok(())
1852                                        }
1853                                    },
1854                                ),
1855                                subscriber.clone(),
1856                            )
1857                            .await;
1858                        }
1859                    })
1860                }
1861            });
1862
1863        let build_client = |vnode| async {
1864            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1865                Ok(tonic::Response::new(
1866                    manager
1867                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1868                        .await
1869                        .unwrap()
1870                        .boxed(),
1871                ))
1872            })
1873            .await
1874            .unwrap()
1875            .0
1876        };
1877
1878        let mut client = build_client(vnode).await;
1879
1880        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
1881        assert_eq!(aligned_epoch, 1);
1882
1883        client
1884            .commit(
1885                epoch1,
1886                SinkMetadata {
1887                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1888                        metadata: metadata.clone(),
1889                    })),
1890                },
1891                None,
1892            )
1893            .await
1894            .unwrap();
1895
1896        // wait max 10 seconds for the commit to be processed
1897        for _ in 0..100 {
1898            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1899            let rows = list_rows(&db).await;
1900            if rows[0].2 == "COMMITTED" {
1901                break;
1902            }
1903        }
1904
1905        assert_eq!(commit_attempt.load(std::sync::atomic::Ordering::SeqCst), 3);
1906
1907        {
1908            let rows = list_rows(&db).await;
1909            assert_eq!(rows.len(), 1);
1910            assert_eq!(rows[0].1, epoch1 as i64);
1911            assert_eq!(rows[0].2, "COMMITTED");
1912        }
1913    }
1914
1915    #[tokio::test]
1916    async fn test_aborted() {
1917        let db = prepare_db_backend().await;
1918
1919        let param = SinkParam {
1920            sink_id: SinkId::from(1),
1921            sink_name: "test".into(),
1922            properties: Default::default(),
1923            columns: vec![],
1924            downstream_pk: None,
1925            sink_type: SinkType::AppendOnly,
1926            format_desc: None,
1927            db_name: "test".into(),
1928            sink_from_name: "test".into(),
1929        };
1930
1931        let epoch0 = 232;
1932        let epoch1 = 233;
1933
1934        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1935        let build_bitmap = |indexes: &[usize]| {
1936            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1937            for i in indexes {
1938                builder.set(*i, true);
1939            }
1940            builder.finish()
1941        };
1942        let vnode = build_bitmap(&all_vnode);
1943
1944        let metadata = vec![1u8, 2u8];
1945
1946        let sender = Arc::new(tokio::sync::Mutex::new(None));
1947        let mock_subscriber: SinkCommittedEpochSubscriber = {
1948            let captured_sender = sender.clone();
1949            Arc::new(move |_sink_id: SinkId| {
1950                let (sender, receiver) = unbounded_channel();
1951                let captured_sender = captured_sender.clone();
1952                async move {
1953                    let mut guard = captured_sender.lock().await;
1954                    *guard = Some(sender);
1955                    Ok((epoch0, receiver))
1956                }
1957                .boxed()
1958            })
1959        };
1960
1961        let (manager, (_join_handle, _stop_tx)) =
1962            SinkCoordinatorManager::start_worker_with_spawn_worker({
1963                let expected_param = param.clone();
1964                let metadata = metadata.clone();
1965                let db = db.clone();
1966                move |param, new_writer_rx| {
1967                    let metadata = metadata.clone();
1968                    let expected_param = expected_param.clone();
1969                    let db = db.clone();
1970                    tokio::spawn({
1971                        let subscriber = mock_subscriber.clone();
1972                        async move {
1973                            // validate the start request
1974                            assert_eq!(param, expected_param);
1975                            CoordinatorWorker::execute_coordinator(
1976                                db,
1977                                param.clone(),
1978                                new_writer_rx,
1979                                MockTwoPhaseCoordinator::new_coordinator(
1980                                    move |_epoch, metadata_list| {
1981                                        let metadata =
1982                                            metadata_list.into_iter().exactly_one().unwrap();
1983                                        Ok(match metadata.metadata {
1984                                            Some(Metadata::Serialized(SerializedMetadata {
1985                                                metadata,
1986                                            })) => metadata,
1987                                            _ => unreachable!(),
1988                                        })
1989                                    },
1990                                    move |_epoch, commit_metadata| {
1991                                        assert_eq!(commit_metadata, metadata);
1992                                        Ok(())
1993                                    },
1994                                ),
1995                                subscriber.clone(),
1996                            )
1997                            .await;
1998                        }
1999                    })
2000                }
2001            });
2002
2003        let build_client = |vnode| async {
2004            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
2005                Ok(tonic::Response::new(
2006                    manager
2007                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
2008                        .await
2009                        .unwrap()
2010                        .boxed(),
2011                ))
2012            })
2013            .await
2014            .unwrap()
2015            .0
2016        };
2017
2018        let mut client = build_client(vnode.clone()).await;
2019
2020        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
2021        assert_eq!(aligned_epoch, 1);
2022
2023        client
2024            .commit(
2025                epoch1,
2026                SinkMetadata {
2027                    metadata: Some(Metadata::Serialized(SerializedMetadata {
2028                        metadata: metadata.clone(),
2029                    })),
2030                },
2031                None,
2032            )
2033            .await
2034            .unwrap();
2035
2036        manager.stop_sink_coordinator(vec![SinkId::from(1)]).await;
2037
2038        {
2039            let rows = list_rows(&db).await;
2040            assert_eq!(rows.len(), 1);
2041            assert_eq!(rows[0].1, epoch1 as i64);
2042            assert_eq!(rows[0].2, "PENDING");
2043
2044            set_epoch_aborted(&db, SinkId::from(1), epoch1).await;
2045            let rows = list_rows(&db).await;
2046            assert_eq!(rows.len(), 1);
2047            assert_eq!(rows[0].1, epoch1 as i64);
2048            assert_eq!(rows[0].2, "ABORTED");
2049        }
2050
2051        let mut client = build_client(vnode).await;
2052
2053        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
2054        assert_eq!(aligned_epoch, 1);
2055
2056        {
2057            let rows = list_rows(&db).await;
2058            assert!(rows.is_empty());
2059        }
2060    }
2061}