risingwave_meta/manager/sink_coordination/
manager.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::pin::pin;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use futures::future::{BoxFuture, Either, join_all, select};
21use futures::stream::FuturesUnordered;
22use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
23use risingwave_common::bitmap::Bitmap;
24use risingwave_connector::connector_common::IcebergSinkCompactionUpdate;
25use risingwave_connector::sink::catalog::SinkId;
26use risingwave_connector::sink::{SinkCommittedEpochSubscriber, SinkError, SinkParam};
27use risingwave_pb::connector_service::coordinate_request::Msg;
28use risingwave_pb::connector_service::{CoordinateRequest, CoordinateResponse, coordinate_request};
29use rw_futures_util::pending_on_none;
30use sea_orm::DatabaseConnection;
31use thiserror_ext::AsReport;
32use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
33use tokio::sync::oneshot::{Receiver, Sender, channel};
34use tokio::sync::{mpsc, oneshot};
35use tokio::task::{JoinError, JoinHandle};
36use tokio_stream::wrappers::UnboundedReceiverStream;
37use tonic::Status;
38use tracing::{debug, error, info, warn};
39
40use crate::hummock::HummockManagerRef;
41use crate::manager::MetadataManager;
42use crate::manager::sink_coordination::SinkWriterRequestStream;
43use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
44use crate::manager::sink_coordination::handle::SinkWriterCoordinationHandle;
45
/// Sends `$msg` through `$tx` using its synchronous `send` method.
///
/// A failed send is not propagated: it is only reported with an `error!` log.
macro_rules! send_with_err_check {
    ($tx:expr, $msg:expr) => {{
        let send_failed = $tx.send($msg).is_err();
        if send_failed {
            error!("unable to send msg");
        }
    }};
}
53
/// Sends `$msg` through `$tx` using its asynchronous `send` method and awaits
/// the result. Must be expanded inside an async context.
///
/// A failed send is not propagated: it is only reported with an `error!` log.
macro_rules! send_await_with_err_check {
    ($tx:expr, $msg:expr) => {{
        let send_failed = $tx.send($msg).await.is_err();
        if send_failed {
            error!("unable to send msg");
        }
    }};
}
61
/// Capacity of the bounded `mpsc` channel that carries `ManagerRequest`s from
/// `SinkCoordinatorManager` to the manager worker task.
const BOUNDED_CHANNEL_SIZE: usize = 16;
63
/// Requests sent from a `SinkCoordinatorManager` to the background manager worker.
enum ManagerRequest {
    /// A new sink writer has connected and must be routed to the coordinator
    /// worker of its sink.
    NewSinkWriter(SinkWriterCoordinationHandle),
    /// Stop some or all of the running coordinator workers.
    StopCoordinator {
        /// Notified once the requested coordinator workers have stopped.
        finish_notifier: Sender<()>,
        /// Ids of the sinks whose coordinators should stop. When `None`, all
        /// sink coordinators are stopped.
        sink_ids: Option<Vec<SinkId>>,
    },
}
72
/// Cloneable front-end of the sink coordination service.
///
/// It holds only the sending half of the request channel; all requests are
/// forwarded to the manager worker task that owns the receiving half.
#[derive(Clone)]
pub struct SinkCoordinatorManager {
    /// Sender of the bounded channel consumed by the manager worker.
    request_tx: mpsc::Sender<ManagerRequest>,
}
77
78fn new_committed_epoch_subscriber(
79    hummock_manager: HummockManagerRef,
80    metadata_manager: MetadataManager,
81) -> SinkCommittedEpochSubscriber {
82    Arc::new(move |sink_id| {
83        let hummock_manager = hummock_manager.clone();
84        let metadata_manager = metadata_manager.clone();
85        async move {
86            let state_table_ids = metadata_manager
87                .get_sink_state_table_ids(sink_id)
88                .await
89                .map_err(SinkError::from)?;
90            let Some(table_id) = state_table_ids.first() else {
91                return Err(anyhow!("no state table id in sink: {}", sink_id).into());
92            };
93            hummock_manager
94                .subscribe_table_committed_epoch(*table_id)
95                .await
96                .map_err(SinkError::from)
97        }
98        .boxed()
99    })
100}
101
102impl SinkCoordinatorManager {
103    pub fn start_worker(
104        db: DatabaseConnection,
105        hummock_manager: HummockManagerRef,
106        metadata_manager: MetadataManager,
107        iceberg_compact_stat_sender: UnboundedSender<IcebergSinkCompactionUpdate>,
108        await_tree_reg: await_tree::Registry,
109    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
110        let subscriber = new_committed_epoch_subscriber(hummock_manager, metadata_manager);
111        Self::start_worker_with_spawn_worker({
112            move |param, manager_request_stream| {
113                let sink_id = param.sink_id;
114                let fut = CoordinatorWorker::run(
115                    param,
116                    manager_request_stream,
117                    db.clone(),
118                    subscriber.clone(),
119                    iceberg_compact_stat_sender.clone(),
120                );
121                tokio::spawn(
122                    await_tree_reg
123                        .register_derived_root(format!("Sink Coordinator {sink_id}"))
124                        .instrument(fut),
125                )
126            }
127        })
128    }
129
130    fn start_worker_with_spawn_worker(
131        spawn_coordinator_worker: impl SpawnCoordinatorFn,
132    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
133        let (request_tx, request_rx) = mpsc::channel(BOUNDED_CHANNEL_SIZE);
134        let (shutdown_tx, shutdown_rx) = channel();
135        let worker = ManagerWorker::new(request_rx, shutdown_rx);
136        let join_handle = tokio::spawn(worker.execute(spawn_coordinator_worker));
137        (
138            SinkCoordinatorManager { request_tx },
139            (join_handle, shutdown_tx),
140        )
141    }
142
143    pub async fn handle_new_request(
144        &self,
145        mut request_stream: SinkWriterRequestStream,
146    ) -> Result<impl Stream<Item = Result<CoordinateResponse, Status>> + use<>, Status> {
147        let (param, vnode_bitmap) = match request_stream.try_next().await? {
148            Some(CoordinateRequest {
149                msg:
150                    Some(Msg::StartRequest(coordinate_request::StartCoordinationRequest {
151                        param: Some(param),
152                        vnode_bitmap: Some(vnode_bitmap),
153                    })),
154            }) => (SinkParam::from_proto(param), Bitmap::from(&vnode_bitmap)),
155            msg => {
156                return Err(Status::invalid_argument(format!(
157                    "expected CoordinateRequest::StartRequest in the first request, get {:?}",
158                    msg
159                )));
160            }
161        };
162        let (response_tx, response_rx) = mpsc::unbounded_channel();
163        self.request_tx
164            .send(ManagerRequest::NewSinkWriter(
165                SinkWriterCoordinationHandle::new(request_stream, response_tx, param, vnode_bitmap),
166            ))
167            .await
168            .map_err(|_| {
169                Status::unavailable(
170                    "unable to send to sink manager worker. The worker may have stopped",
171                )
172            })?;
173
174        Ok(UnboundedReceiverStream::new(response_rx))
175    }
176
177    async fn stop_coordinator(&self, sink_ids: Option<Vec<SinkId>>) {
178        let (tx, rx) = channel();
179        send_await_with_err_check!(
180            self.request_tx,
181            ManagerRequest::StopCoordinator {
182                finish_notifier: tx,
183                sink_ids: sink_ids.clone(),
184            }
185        );
186        if rx.await.is_err() {
187            error!("fail to wait for resetting sink manager worker");
188        }
189        info!("successfully stop coordinator: {:?}", sink_ids);
190    }
191
    /// Stops every running sink coordinator and waits for the manager worker
    /// to acknowledge it.
    pub async fn reset(&self) {
        self.stop_coordinator(None).await;
    }
195
    /// Stops the coordinators of the given sinks and waits for the manager
    /// worker to acknowledge it.
    pub async fn stop_sink_coordinator(&self, sink_ids: Vec<SinkId>) {
        self.stop_coordinator(Some(sink_ids)).await;
    }
199}
200
/// Bookkeeping the manager worker keeps for one running coordinator worker.
struct CoordinatorWorkerHandle {
    /// Sender to the coordinator worker. Dropping the sender is the stop
    /// signal, so `None` means the worker has already been asked to stop.
    request_sender: Option<UnboundedSender<SinkWriterCoordinationHandle>>,
    /// Notified when the coordinator worker stops.
    finish_notifiers: Vec<Sender<()>>,
}
207
/// State of the background task that routes sink writers to per-sink
/// coordinator workers and tracks the lifecycle of those workers.
struct ManagerWorker {
    /// Requests coming from `SinkCoordinatorManager` handles.
    request_rx: mpsc::Receiver<ManagerRequest>,
    /// Shutdown signal; `next_event` polls it through `&mut` and ends the
    /// event loop once it resolves.
    shutdown_rx: Receiver<()>,

    /// Join handles of the spawned coordinator workers, each resolving to the
    /// sink id together with the task's join result.
    running_coordinator_worker_join_handles:
        FuturesUnordered<BoxFuture<'static, (SinkId, Result<(), JoinError>)>>,
    /// Per-sink handle of every coordinator worker that has not finished yet.
    running_coordinator_worker: HashMap<SinkId, CoordinatorWorkerHandle>,
}
217
/// Events processed by the manager worker's event loop.
enum ManagerEvent {
    /// A request arrived on the request channel.
    NewRequest(ManagerRequest),
    /// The task of a coordinator worker has completed.
    CoordinatorWorkerFinished {
        /// Sink whose coordinator worker finished.
        sink_id: SinkId,
        /// Result of joining the worker task; `Err` when the join failed.
        join_result: Result<(), JoinError>,
    },
}
225
/// Factory that launches the coordinator worker task for a sink.
///
/// It receives the sink's param and the receiving end of the channel through
/// which new writer handles are delivered, and returns the join handle of the
/// spawned task.
trait SpawnCoordinatorFn = FnMut(SinkParam, UnboundedReceiver<SinkWriterCoordinationHandle>) -> JoinHandle<()>
    + Send
    + 'static;
229
230impl ManagerWorker {
231    fn new(request_rx: mpsc::Receiver<ManagerRequest>, shutdown_rx: Receiver<()>) -> Self {
232        ManagerWorker {
233            request_rx,
234            shutdown_rx,
235            running_coordinator_worker_join_handles: Default::default(),
236            running_coordinator_worker: Default::default(),
237        }
238    }
239
240    async fn execute(mut self, mut spawn_coordinator_worker: impl SpawnCoordinatorFn) {
241        while let Some(event) = self.next_event().await {
242            match event {
243                ManagerEvent::NewRequest(request) => match request {
244                    ManagerRequest::NewSinkWriter(request) => {
245                        self.handle_new_sink_writer(request, &mut spawn_coordinator_worker)
246                    }
247                    ManagerRequest::StopCoordinator {
248                        finish_notifier,
249                        sink_ids,
250                    } => {
251                        if let Some(sink_ids) = sink_ids {
252                            let mut rxs = Vec::with_capacity(sink_ids.len());
253                            for sink_id in sink_ids {
254                                if let Some(worker_handle) =
255                                    self.running_coordinator_worker.get_mut(&sink_id)
256                                {
257                                    let (tx, rx) = oneshot::channel();
258                                    rxs.push(rx);
259                                    worker_handle.finish_notifiers.push(tx);
260                                    if let Some(sender) = worker_handle.request_sender.take() {
261                                        // drop the sender as a signal to notify the coordinator worker
262                                        // to stop
263                                        drop(sender);
264                                    }
265                                } else {
266                                    debug!(
267                                        "sink coordinator of {} is not running, skip it",
268                                        sink_id
269                                    );
270                                }
271                            }
272                            tokio::spawn(async move {
273                                let notify_res = join_all(rxs).await;
274                                for res in notify_res {
275                                    if let Err(e) = res {
276                                        error!(
277                                            "fail to wait for resetting sink manager worker: {}",
278                                            e.as_report()
279                                        );
280                                    }
281                                }
282                                send_with_err_check!(finish_notifier, ());
283                            });
284                        } else {
285                            self.clean_up().await;
286                            send_with_err_check!(finish_notifier, ());
287                        }
288                    }
289                },
290                ManagerEvent::CoordinatorWorkerFinished {
291                    sink_id,
292                    join_result,
293                } => self.handle_coordinator_finished(sink_id, join_result),
294            }
295        }
296        self.clean_up().await;
297        info!("sink manager worker exited");
298    }
299
300    async fn next_event(&mut self) -> Option<ManagerEvent> {
301        match select(
302            select(
303                pin!(self.request_rx.recv()),
304                pin!(pending_on_none(
305                    self.running_coordinator_worker_join_handles.next()
306                )),
307            ),
308            &mut self.shutdown_rx,
309        )
310        .await
311        {
312            Either::Left((either, _)) => match either {
313                Either::Left((Some(request), _)) => Some(ManagerEvent::NewRequest(request)),
314                Either::Left((None, _)) => None,
315                Either::Right(((sink_id, join_result), _)) => {
316                    Some(ManagerEvent::CoordinatorWorkerFinished {
317                        sink_id,
318                        join_result,
319                    })
320                }
321            },
322            Either::Right(_) => None,
323        }
324    }
325
326    async fn clean_up(&mut self) {
327        info!("sink manager worker start cleaning up");
328        for worker_handle in self.running_coordinator_worker.values_mut() {
329            if let Some(sender) = worker_handle.request_sender.take() {
330                // drop the sender to notify the coordinator worker to stop
331                drop(sender);
332            }
333        }
334        while let Some((sink_id, join_result)) =
335            self.running_coordinator_worker_join_handles.next().await
336        {
337            self.handle_coordinator_finished(sink_id, join_result);
338        }
339        info!("sink manager worker finished cleaning up");
340    }
341
342    fn handle_coordinator_finished(&mut self, sink_id: SinkId, join_result: Result<(), JoinError>) {
343        let worker_handle = self
344            .running_coordinator_worker
345            .remove(&sink_id)
346            .expect("finished coordinator should have an associated worker handle");
347        for finish_notifier in worker_handle.finish_notifiers {
348            send_with_err_check!(finish_notifier, ());
349        }
350        match join_result {
351            Ok(()) => {
352                info!(
353                    id = %sink_id,
354                    "sink coordinator has gracefully finished",
355                );
356            }
357            Err(err) => {
358                error!(
359                    id = %sink_id,
360                    error = %err.as_report(),
361                    "sink coordinator finished with error",
362                );
363            }
364        }
365    }
366
367    fn handle_new_sink_writer(
368        &mut self,
369        new_writer: SinkWriterCoordinationHandle,
370        spawn_coordinator_worker: &mut impl SpawnCoordinatorFn,
371    ) {
372        let param = new_writer.param();
373        let sink_id = param.sink_id;
374
375        let handle = self
376            .running_coordinator_worker
377            .entry(param.sink_id)
378            .or_insert_with(|| {
379                // Launch the coordinator worker task if it is the first
380                let (request_tx, request_rx) = unbounded_channel();
381                let join_handle = spawn_coordinator_worker(param.clone(), request_rx);
382                self.running_coordinator_worker_join_handles.push(
383                    join_handle
384                        .map(move |join_result| (sink_id, join_result))
385                        .boxed(),
386                );
387                CoordinatorWorkerHandle {
388                    request_sender: Some(request_tx),
389                    finish_notifiers: Vec::new(),
390                }
391            });
392
393        if let Some(sender) = handle.request_sender.as_mut() {
394            send_with_err_check!(sender, new_writer);
395        } else {
396            warn!(
397                "handle a new request while the sink coordinator is being stopped: {:?}",
398                param
399            );
400            new_writer.abort(Status::internal("the sink is being stopped"));
401        }
402    }
403}
404
405#[cfg(test)]
406mod tests {
407    use std::future::{Future, poll_fn};
408    use std::pin::pin;
409    use std::sync::Arc;
410    use std::sync::atomic::AtomicI32;
411    use std::task::Poll;
412
413    use anyhow::anyhow;
414    use async_trait::async_trait;
415    use futures::future::{join, try_join};
416    use futures::{FutureExt, StreamExt, TryFutureExt};
417    use itertools::Itertools;
418    use rand::seq::SliceRandom;
419    use risingwave_common::bitmap::BitmapBuilder;
420    use risingwave_common::hash::VirtualNode;
421    use risingwave_connector::sink::catalog::{SinkId, SinkType};
422    use risingwave_connector::sink::{
423        SinglePhaseCommitCoordinator, SinkCommitCoordinator, SinkError, SinkParam,
424        TwoPhaseCommitCoordinator,
425    };
426    use risingwave_meta_model::SinkSchemachange;
427    use risingwave_pb::connector_service::SinkMetadata;
428    use risingwave_pb::connector_service::sink_metadata::{Metadata, SerializedMetadata};
429    use risingwave_pb::data::PbDataType;
430    use risingwave_pb::data::data_type::PbTypeName;
431    use risingwave_pb::plan_common::PbField;
432    use risingwave_pb::stream_plan::sink_schema_change::Op as SinkSchemachangeOp;
433    use risingwave_pb::stream_plan::{PbSinkAddColumnsOp, PbSinkSchemaChange};
434    use risingwave_rpc_client::CoordinatorStreamHandle;
435    use sea_orm::{ConnectionTrait, Database, DatabaseConnection};
436    use tokio::sync::mpsc::unbounded_channel;
437    use tokio_stream::wrappers::ReceiverStream;
438
439    use crate::manager::sink_coordination::SinkCoordinatorManager;
440    use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
441    use crate::manager::sink_coordination::manager::SinkCommittedEpochSubscriber;
442
    /// Test double for a single-phase commit coordinator.
    ///
    /// Every `commit_data` call is forwarded to the closure `f` together with
    /// a mutable reference to the user-supplied `context`.
    struct MockSinglePhaseCoordinator<
        C,
        F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError>,
    > {
        /// State handed to `f` on every commit.
        context: C,
        /// Callback invoked with `(epoch, metadata, &mut context)`.
        f: F,
    }
450
451    impl<
452        C: Send + 'static,
453        F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError> + Send + 'static,
454    > MockSinglePhaseCoordinator<C, F>
455    {
456        fn new_coordinator(context: C, f: F) -> SinkCommitCoordinator {
457            SinkCommitCoordinator::SinglePhase(Box::new(MockSinglePhaseCoordinator { context, f }))
458        }
459    }
460
461    #[async_trait]
462    impl<C: Send, F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError> + Send>
463        SinglePhaseCommitCoordinator for MockSinglePhaseCoordinator<C, F>
464    {
465        async fn init(&mut self) -> risingwave_connector::sink::Result<()> {
466            Ok(())
467        }
468
469        async fn commit_data(
470            &mut self,
471            epoch: u64,
472            metadata: Vec<SinkMetadata>,
473        ) -> risingwave_connector::sink::Result<()> {
474            (self.f)(epoch, metadata, &mut self.context)
475        }
476
477        async fn commit_schema_change(
478            &mut self,
479            _epoch: u64,
480            _schema_change: PbSinkSchemaChange,
481        ) -> risingwave_connector::sink::Result<()> {
482            unreachable!()
483        }
484    }
485
    /// Two writers of the same sink, each owning one half of the vnodes,
    /// align their initial epoch and then commit two epochs through a mock
    /// single-phase coordinator that validates every commit it receives.
    #[tokio::test]
    async fn test_basic() {
        let db = prepare_db_backend().await;

        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: None,
            sink_type: SinkType::AppendOnly,
            ignore_delete: false,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch0 = 232;
        let epoch1 = 233;
        let epoch2 = 234;

        // Shuffle all vnodes and split them in two, so that each writer owns a
        // random half of the vnode space.
        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);

        // `metadata[i][j]`: payload committed in the i-th commit by the j-th writer.
        let metadata = [
            [vec![1u8, 2u8], vec![3u8, 4u8]],
            [vec![5u8, 6u8], vec![7u8, 8u8]],
        ];
        // Slot in which the mock subscriber stores the sender of the channel it
        // creates, keeping that sender alive.
        let sender = Arc::new(tokio::sync::Mutex::new(None));
        let mock_subscriber: SinkCommittedEpochSubscriber = {
            let captured_sender = sender.clone();
            Arc::new(move |_sink_id: SinkId| {
                let (sender, receiver) = unbounded_channel();
                let captured_sender = captured_sender.clone();
                async move {
                    let mut guard = captured_sender.lock().await;
                    *guard = Some(sender);
                    Ok((1, receiver))
                }
                .boxed()
            })
        };

        // The join handle and the shutdown sender are bound to named variables
        // so they stay alive until the end of the test.
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                let metadata = metadata.clone();
                let db = db.clone();
                move |param, new_writer_rx| {
                    let metadata = metadata.clone();
                    let expected_param = expected_param.clone();
                    let db = db.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                db,
                                param.clone(),
                                new_writer_rx,
                                MockSinglePhaseCoordinator::new_coordinator(
                                    0,
                                    // `count` is the number of commits seen so far.
                                    move |epoch, metadata_list, count: &mut usize| {
                                        *count += 1;
                                        let mut metadata_list =
                                            metadata_list
                                                .into_iter()
                                                .map(|metadata| match metadata {
                                                    SinkMetadata {
                                                        metadata:
                                                            Some(Metadata::Serialized(
                                                                SerializedMetadata { metadata },
                                                            )),
                                                    } => metadata,
                                                    _ => unreachable!(),
                                                })
                                                .collect_vec();
                                        // Sort so the assertions do not depend on
                                        // the order in which writers are listed.
                                        metadata_list.sort();
                                        match *count {
                                            1 => {
                                                assert_eq!(epoch, epoch1);
                                                assert_eq!(2, metadata_list.len());
                                                assert_eq!(metadata[0][0], metadata_list[0]);
                                                assert_eq!(metadata[0][1], metadata_list[1]);
                                            }
                                            2 => {
                                                assert_eq!(epoch, epoch2);
                                                assert_eq!(2, metadata_list.len());
                                                assert_eq!(metadata[1][0], metadata_list[0]);
                                                assert_eq!(metadata[1][1], metadata_list[1]);
                                            }
                                            _ => unreachable!(),
                                        }
                                        Ok(())
                                    },
                                ),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        // Builds a client whose request stream is fed straight into
        // `manager.handle_new_request`, without going through a real RPC.
        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
            .unwrap()
            .0
        };

        let (mut client1, mut client2) =
            join(build_client(vnode1), pin!(build_client(vnode2))).await;

        // The writers propose different initial epochs; both are expected to
        // be aligned to `epoch1`.
        let (aligned_epoch1, aligned_epoch2) = try_join(
            client1.align_initial_epoch(epoch0),
            client2.align_initial_epoch(epoch1),
        )
        .await
        .unwrap();
        assert_eq!(aligned_epoch1, epoch1);
        assert_eq!(aligned_epoch2, epoch1);

        {
            // commit epoch1
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][1].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap())
            );
            // A single poll of the first writer's commit must still be pending
            // while the other writer has not committed this epoch.
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client1
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][0].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }

        // commit epoch2
        let mut commit_future = pin!(
            client1
                .commit(
                    epoch2,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata[1][0].clone(),
                        })),
                    },
                    None,
                )
                .map(|result| result.unwrap())
        );
        // Same check as for epoch1, with the writers committing in the opposite order.
        assert!(
            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        join(
            commit_future,
            client2
                .commit(
                    epoch2,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata[1][1].clone(),
                        })),
                    },
                    None,
                )
                .map(|result| result.unwrap()),
        )
        .await;
    }
700
701    #[tokio::test]
702    async fn test_single_writer() {
703        let db = prepare_db_backend().await;
704        let param = SinkParam {
705            sink_id: SinkId::from(1),
706            sink_name: "test".into(),
707            properties: Default::default(),
708            columns: vec![],
709            downstream_pk: None,
710            sink_type: SinkType::AppendOnly,
711            ignore_delete: false,
712            format_desc: None,
713            db_name: "test".into(),
714            sink_from_name: "test".into(),
715        };
716
717        let epoch1 = 233;
718        let epoch2 = 234;
719
720        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
721        let build_bitmap = |indexes: &[usize]| {
722            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
723            for i in indexes {
724                builder.set(*i, true);
725            }
726            builder.finish()
727        };
728        let vnode = build_bitmap(&all_vnode);
729
730        let metadata = [vec![1u8, 2u8], vec![3u8, 4u8]];
731        let sender = Arc::new(tokio::sync::Mutex::new(None));
732        let mock_subscriber: SinkCommittedEpochSubscriber = {
733            let captured_sender = sender.clone();
734            Arc::new(move |_sink_id: SinkId| {
735                let (sender, receiver) = unbounded_channel();
736                let captured_sender = captured_sender.clone();
737                async move {
738                    let mut guard = captured_sender.lock().await;
739                    *guard = Some(sender);
740                    Ok((1, receiver))
741                }
742                .boxed()
743            })
744        };
745        let (manager, (_join_handle, _stop_tx)) =
746            SinkCoordinatorManager::start_worker_with_spawn_worker({
747                let expected_param = param.clone();
748                let metadata = metadata.clone();
749                let db = db.clone();
750                move |param, new_writer_rx| {
751                    let metadata = metadata.clone();
752                    let expected_param = expected_param.clone();
753                    let db = db.clone();
754                    tokio::spawn({
755                        let subscriber = mock_subscriber.clone();
756                        async move {
757                            // validate the start request
758                            assert_eq!(param, expected_param);
759                            CoordinatorWorker::execute_coordinator(
760                                db,
761                                param.clone(),
762                                new_writer_rx,
763                                MockSinglePhaseCoordinator::new_coordinator(
764                                    0,
765                                    move |epoch, metadata_list, count: &mut usize| {
766                                        *count += 1;
767                                        let mut metadata_list =
768                                            metadata_list
769                                                .into_iter()
770                                                .map(|metadata| match metadata {
771                                                    SinkMetadata {
772                                                        metadata:
773                                                            Some(Metadata::Serialized(
774                                                                SerializedMetadata { metadata },
775                                                            )),
776                                                    } => metadata,
777                                                    _ => unreachable!(),
778                                                })
779                                                .collect_vec();
780                                        metadata_list.sort();
781                                        match *count {
782                                            1 => {
783                                                assert_eq!(epoch, epoch1);
784                                                assert_eq!(1, metadata_list.len());
785                                                assert_eq!(metadata[0], metadata_list[0]);
786                                            }
787                                            2 => {
788                                                assert_eq!(epoch, epoch2);
789                                                assert_eq!(1, metadata_list.len());
790                                                assert_eq!(metadata[1], metadata_list[0]);
791                                            }
792                                            _ => unreachable!(),
793                                        }
794                                        Ok(())
795                                    },
796                                ),
797                                subscriber.clone(),
798                            )
799                            .await;
800                        }
801                    })
802                }
803            });
804
805        let build_client = |vnode| async {
806            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
807                Ok(tonic::Response::new(
808                    manager
809                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
810                        .await
811                        .unwrap()
812                        .boxed(),
813                ))
814            })
815            .await
816            .unwrap()
817            .0
818        };
819
820        let mut client = build_client(vnode).await;
821
822        let aligned_epoch = client.align_initial_epoch(epoch1).await.unwrap();
823        assert_eq!(aligned_epoch, epoch1);
824
825        client
826            .commit(
827                epoch1,
828                SinkMetadata {
829                    metadata: Some(Metadata::Serialized(SerializedMetadata {
830                        metadata: metadata[0].clone(),
831                    })),
832                },
833                None,
834            )
835            .await
836            .unwrap();
837
838        client
839            .commit(
840                epoch2,
841                SinkMetadata {
842                    metadata: Some(Metadata::Serialized(SerializedMetadata {
843                        metadata: metadata[1].clone(),
844                    })),
845                },
846                None,
847            )
848            .await
849            .unwrap();
850    }
851
852    #[tokio::test]
853    async fn test_partial_commit() {
854        let db = prepare_db_backend().await;
855        let param = SinkParam {
856            sink_id: SinkId::from(1),
857            sink_name: "test".into(),
858            properties: Default::default(),
859            columns: vec![],
860            downstream_pk: None,
861            sink_type: SinkType::AppendOnly,
862            ignore_delete: false,
863            format_desc: None,
864            db_name: "test".into(),
865            sink_from_name: "test".into(),
866        };
867
868        let epoch = 233;
869
870        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
871        all_vnode.shuffle(&mut rand::rng());
872        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
873        let build_bitmap = |indexes: &[usize]| {
874            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
875            for i in indexes {
876                builder.set(*i, true);
877            }
878            builder.finish()
879        };
880        let vnode1 = build_bitmap(first);
881        let vnode2 = build_bitmap(second);
882
883        let sender = Arc::new(tokio::sync::Mutex::new(None));
884        let mock_subscriber: SinkCommittedEpochSubscriber = {
885            let captured_sender = sender.clone();
886            Arc::new(move |_sink_id: SinkId| {
887                let (sender, receiver) = unbounded_channel();
888                let captured_sender = captured_sender.clone();
889                async move {
890                    let mut guard = captured_sender.lock().await;
891                    *guard = Some(sender);
892                    Ok((1, receiver))
893                }
894                .boxed()
895            })
896        };
897        let (manager, (_join_handle, _stop_tx)) =
898            SinkCoordinatorManager::start_worker_with_spawn_worker({
899                let expected_param = param.clone();
900                let db = db.clone();
901                move |param, new_writer_rx| {
902                    let expected_param = expected_param.clone();
903                    let db = db.clone();
904                    tokio::spawn({
905                        let subscriber = mock_subscriber.clone();
906                        async move {
907                            // validate the start request
908                            assert_eq!(param, expected_param);
909                            CoordinatorWorker::execute_coordinator(
910                                db,
911                                param,
912                                new_writer_rx,
913                                MockSinglePhaseCoordinator::new_coordinator(
914                                    (),
915                                    |_, _, _| unreachable!(),
916                                ),
917                                subscriber.clone(),
918                            )
919                            .await;
920                        }
921                    })
922                }
923            });
924
925        let build_client = |vnode| async {
926            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
927                Ok(tonic::Response::new(
928                    manager
929                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
930                        .await
931                        .unwrap()
932                        .boxed(),
933                ))
934            })
935            .await
936            .unwrap()
937            .0
938        };
939
940        let (mut client1, client2) = join(build_client(vnode1), build_client(vnode2)).await;
941
942        // commit epoch
943        let mut commit_future = pin!(client1.commit(
944            epoch,
945            SinkMetadata {
946                metadata: Some(Metadata::Serialized(SerializedMetadata {
947                    metadata: vec![],
948                })),
949            },
950            None,
951        ));
952        assert!(
953            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
954                .await
955                .is_pending()
956        );
957        drop(client2);
958        assert!(commit_future.await.is_err());
959    }
960
961    #[tokio::test]
962    async fn test_fail_commit() {
963        let db = prepare_db_backend().await;
964        let param = SinkParam {
965            sink_id: SinkId::from(1),
966            sink_name: "test".into(),
967            properties: Default::default(),
968            columns: vec![],
969            downstream_pk: None,
970            sink_type: SinkType::AppendOnly,
971            ignore_delete: false,
972            format_desc: None,
973            db_name: "test".into(),
974            sink_from_name: "test".into(),
975        };
976
977        let epoch = 233;
978
979        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
980        all_vnode.shuffle(&mut rand::rng());
981        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
982        let build_bitmap = |indexes: &[usize]| {
983            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
984            for i in indexes {
985                builder.set(*i, true);
986            }
987            builder.finish()
988        };
989        let vnode1 = build_bitmap(first);
990        let vnode2 = build_bitmap(second);
991        let sender = Arc::new(tokio::sync::Mutex::new(None));
992        let mock_subscriber: SinkCommittedEpochSubscriber = {
993            let captured_sender = sender.clone();
994            Arc::new(move |_sink_id: SinkId| {
995                let (sender, receiver) = unbounded_channel();
996                let captured_sender = captured_sender.clone();
997                async move {
998                    let mut guard = captured_sender.lock().await;
999                    *guard = Some(sender);
1000                    Ok((1, receiver))
1001                }
1002                .boxed()
1003            })
1004        };
1005        let (manager, (_join_handle, _stop_tx)) =
1006            SinkCoordinatorManager::start_worker_with_spawn_worker({
1007                let expected_param = param.clone();
1008                let db = db.clone();
1009                move |param, new_writer_rx| {
1010                    let expected_param = expected_param.clone();
1011                    let db = db.clone();
1012                    tokio::spawn({
1013                        let subscriber = mock_subscriber.clone();
1014                        {
1015                            async move {
1016                                // validate the start request
1017                                assert_eq!(param, expected_param);
1018                                CoordinatorWorker::execute_coordinator(
1019                                    db,
1020                                    param,
1021                                    new_writer_rx,
1022                                    MockSinglePhaseCoordinator::new_coordinator((), |_, _, _| {
1023                                        Err(SinkError::Coordinator(anyhow!("failed to commit")))
1024                                    }),
1025                                    subscriber.clone(),
1026                                )
1027                                .await;
1028                            }
1029                        }
1030                    })
1031                }
1032            });
1033
1034        let build_client = |vnode| async {
1035            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1036                Ok(tonic::Response::new(
1037                    manager
1038                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1039                        .await
1040                        .unwrap()
1041                        .boxed(),
1042                ))
1043            })
1044            .await
1045            .unwrap()
1046            .0
1047        };
1048
1049        let (mut client1, mut client2) = join(build_client(vnode1), build_client(vnode2)).await;
1050
1051        // commit epoch
1052        let mut commit_future = pin!(client1.commit(
1053            epoch,
1054            SinkMetadata {
1055                metadata: Some(Metadata::Serialized(SerializedMetadata {
1056                    metadata: vec![],
1057                })),
1058            },
1059            None,
1060        ));
1061        assert!(
1062            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
1063                .await
1064                .is_pending()
1065        );
1066        let (result1, result2) = join(
1067            commit_future,
1068            client2.commit(
1069                epoch,
1070                SinkMetadata {
1071                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1072                        metadata: vec![],
1073                    })),
1074                },
1075                None,
1076            ),
1077        )
1078        .await;
1079        assert!(result1.is_err());
1080        assert!(result2.is_err());
1081    }
1082
    /// Exercises vnode bitmap updates across a changing set of writers.
    ///
    /// The test walks through four commits:
    /// 1. `epoch1` and `epoch2` are committed by two writers, each owning half of the vnodes.
    /// 2. A third writer joins while the first two call `update_vnode_bitmap`, and
    ///    `epoch3` is committed by all three writers.
    /// 3. `client1` stops while the other two call `update_vnode_bitmap` again, and
    ///    `epoch4` is committed by the two remaining writers.
    ///
    /// The mock coordinator asserts, for the n-th commit it receives, the epoch
    /// and the sorted list of metadata payloads.
    #[tokio::test]
    async fn test_update_vnode_bitmap() {
        let db = prepare_db_backend().await;
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: None,
            sink_type: SinkType::AppendOnly,
            ignore_delete: false,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch1 = 233;
        let epoch2 = 234;
        let epoch3 = 235;
        let epoch4 = 236;

        // Shuffle the vnode indexes and split them into two halves, one per initial writer.
        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        // Builds a bitmap with exactly the given indexes set.
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);

        // `metadata[k]` holds the payloads for the (k+1)-th commit: index 0 is sent
        // by client1 and index 1 by client2.
        let metadata = [
            [vec![1u8, 2u8], vec![3u8, 4u8]],
            [vec![5u8, 6u8], vec![7u8, 8u8]],
        ];

        // Payloads for the three-writer commit (epoch3) and the two-writer commit (epoch4).
        let metadata_scale_out = [vec![9u8, 10u8], vec![11u8, 12u8], vec![13u8, 14u8]];
        let metadata_scale_in = [vec![13u8, 14u8], vec![15u8, 16u8]];
        // Slot in which the mock subscriber stores the sender half of its channel.
        // NOTE(review): presumably kept so the channel stays open for the test's duration — confirm.
        let sender = Arc::new(tokio::sync::Mutex::new(None));
        // Mock subscriber: creates a fresh unbounded channel per call, stores the
        // sender in `sender`, and returns `(1, receiver)`.
        let mock_subscriber: SinkCommittedEpochSubscriber = {
            let captured_sender = sender.clone();
            Arc::new(move |_sink_id: SinkId| {
                let (sender, receiver) = unbounded_channel();
                let captured_sender = captured_sender.clone();
                async move {
                    let mut guard = captured_sender.lock().await;
                    *guard = Some(sender);
                    Ok((1, receiver))
                }
                .boxed()
            })
        };
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                let metadata = metadata.clone();
                let metadata_scale_out = metadata_scale_out.clone();
                let metadata_scale_in = metadata_scale_in.clone();
                let db = db.clone();
                move |param, new_writer_rx| {
                    // Clone per invocation so each spawned worker owns its expectations.
                    let metadata = metadata.clone();
                    let metadata_scale_out = metadata_scale_out.clone();
                    let metadata_scale_in = metadata_scale_in.clone();
                    let expected_param = expected_param.clone();
                    let db = db.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                db,
                                param.clone(),
                                new_writer_rx,
                                MockSinglePhaseCoordinator::new_coordinator(
                                    0,
                                    // `count` is the mock's state (seeded with 0 above); it
                                    // is bumped on every call to number the commits.
                                    move |epoch, metadata_list, count: &mut usize| {
                                        *count += 1;
                                        // Unwrap every entry to its raw bytes; any other
                                        // metadata shape is a test failure.
                                        let mut metadata_list =
                                            metadata_list
                                                .into_iter()
                                                .map(|metadata| match metadata {
                                                    SinkMetadata {
                                                        metadata:
                                                            Some(Metadata::Serialized(
                                                                SerializedMetadata { metadata },
                                                            )),
                                                    } => metadata,
                                                    _ => unreachable!(),
                                                })
                                                .collect_vec();
                                        // Sort so the comparison does not depend on the order
                                        // in which the writers' metadata was collected.
                                        metadata_list.sort();
                                        let (expected_epoch, expected_metadata_list) = match *count
                                        {
                                            1 => (epoch1, metadata[0].as_slice()),
                                            2 => (epoch2, metadata[1].as_slice()),
                                            3 => (epoch3, metadata_scale_out.as_slice()),
                                            4 => (epoch4, metadata_scale_in.as_slice()),
                                            _ => unreachable!(),
                                        };
                                        assert_eq!(expected_epoch, epoch);
                                        assert_eq!(expected_metadata_list, &metadata_list);
                                        Ok(())
                                    },
                                ),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        // Opens a writer stream against `manager` for the given vnode bitmap. Unlike
        // the sibling tests, the whole result is returned so callers can inspect its
        // second element as well.
        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
        };

        let ((mut client1, _), (mut client2, _)) =
            try_join(build_client(vnode1), pin!(build_client(vnode2)))
                .await
                .unwrap();

        // Both writers align on epoch1 and must get epoch1 back.
        let (aligned_epoch1, aligned_epoch2) = try_join(
            client1.align_initial_epoch(epoch1),
            client2.align_initial_epoch(epoch1),
        )
        .await
        .unwrap();
        assert_eq!(aligned_epoch1, epoch1);
        assert_eq!(aligned_epoch2, epoch1);

        {
            // commit epoch1
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][1].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap())
            );
            // client2's commit alone must not complete.
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            // Once client1 commits as well, both commits finish.
            join(
                commit_future,
                client1
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][0].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }

        // Re-partition the same shuffled vnodes into three parts for the scale-out step.
        let (vnode1, vnode2, vnode3) = {
            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            let (second, third) = second.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            (
                build_bitmap(first),
                build_bitmap(second),
                build_bitmap(third),
            )
        };

        // Start registering the third writer. It is expected to still be pending
        // after the first poll; it is driven to completion further below, together
        // with the `update_vnode_bitmap` calls of the existing writers.
        let mut build_client3_future = pin!(build_client(vnode3));
        assert!(
            poll_fn(|cx| Poll::Ready(build_client3_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        let mut client3;
        {
            {
                // commit epoch2
                let mut commit_future = pin!(
                    client1
                        .commit(
                            epoch2,
                            SinkMetadata {
                                metadata: Some(Metadata::Serialized(SerializedMetadata {
                                    metadata: metadata[1][0].clone(),
                                })),
                            },
                            None,
                        )
                        .map_err(Into::into)
                );
                // client1's commit alone must not complete.
                assert!(
                    poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                        .await
                        .is_pending()
                );
                try_join(
                    commit_future,
                    client2.commit(
                        epoch2,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[1][1].clone(),
                            })),
                        },
                        None,
                    ),
                )
                .await
                .unwrap();
            }

            // Scale out: finish building client3 concurrently with both existing
            // writers updating their bitmaps. All three report epoch2.
            client3 = {
                let (
                    (client3, init_epoch),
                    (update_vnode_bitmap_epoch1, update_vnode_bitmap_epoch2),
                ) = try_join(
                    build_client3_future,
                    try_join(
                        client1.update_vnode_bitmap(&vnode1),
                        client2.update_vnode_bitmap(&vnode2),
                    )
                    .map_err(Into::into),
                )
                .await
                .unwrap();
                assert_eq!(init_epoch, Some(epoch2));
                assert_eq!(update_vnode_bitmap_epoch1, epoch2);
                assert_eq!(update_vnode_bitmap_epoch2, epoch2);
                client3
            };
            // Commit epoch3 with three writers: the commit stays pending after
            // client3 alone, and after client3 plus client1.
            let mut commit_future3 = pin!(client3.commit(
                epoch3,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata_scale_out[2].clone(),
                    })),
                },
                None,
            ));
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            let mut commit_future1 = pin!(client1.commit(
                epoch3,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata_scale_out[0].clone(),
                    })),
                },
                None,
            ));
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future1.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            // client2 is the last writer to commit epoch3; all three commits then succeed.
            try_join(
                client2.commit(
                    epoch3,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata_scale_out[1].clone(),
                        })),
                    },
                    None,
                ),
                try_join(commit_future1, commit_future3),
            )
            .await
            .unwrap();
        }

        // Re-partition into two parts (split at one third) for the scale-in step.
        let (vnode2, vnode3) = {
            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            (build_bitmap(first), build_bitmap(second))
        };

        {
            // Scale in: client1 stops while client2 and client3 take over the
            // vnodes. Both bitmap updates report epoch3.
            let (_, (update_vnode_bitmap_epoch2, update_vnode_bitmap_epoch3)) = try_join(
                client1.stop(),
                try_join(
                    client2.update_vnode_bitmap(&vnode2),
                    client3.update_vnode_bitmap(&vnode3),
                ),
            )
            .await
            .unwrap();
            assert_eq!(update_vnode_bitmap_epoch2, epoch3);
            assert_eq!(update_vnode_bitmap_epoch3, epoch3);
        }

        {
            // Commit epoch4 with the two remaining writers: pending after client2
            // alone, complete once client3 commits too.
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch4,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata_scale_in[0].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap())
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client3
                    .commit(
                        epoch4,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata_scale_in[1].clone(),
                            })),
                        },
                        None,
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }
    }
1440
1441    struct MockTwoPhaseCoordinator<
1442        P: FnMut(
1443            u64,
1444            Vec<SinkMetadata>,
1445            Option<PbSinkSchemaChange>,
1446        ) -> Result<Option<Vec<u8>>, SinkError>,
1447        CD: FnMut(u64, Vec<u8>) -> Result<(), SinkError>,
1448        CS: FnMut(u64, PbSinkSchemaChange) -> Result<(), SinkError>,
1449    > {
1450        pre_commit: P,
1451        commit_data: CD,
1452        commit_schema_change: CS,
1453    }
1454
1455    impl<
1456        P: FnMut(
1457                u64,
1458                Vec<SinkMetadata>,
1459                Option<PbSinkSchemaChange>,
1460            ) -> Result<Option<Vec<u8>>, SinkError>
1461            + Send
1462            + 'static,
1463        CD: FnMut(u64, Vec<u8>) -> Result<(), SinkError> + Send + 'static,
1464        CS: FnMut(u64, PbSinkSchemaChange) -> Result<(), SinkError> + Send + 'static,
1465    > MockTwoPhaseCoordinator<P, CD, CS>
1466    {
1467        fn new_coordinator(
1468            pre_commit: P,
1469            commit_data: CD,
1470            commit_schema_change: CS,
1471        ) -> SinkCommitCoordinator {
1472            SinkCommitCoordinator::TwoPhase(Box::new(MockTwoPhaseCoordinator {
1473                pre_commit,
1474                commit_data,
1475                commit_schema_change,
1476            }))
1477        }
1478    }
1479
1480    #[async_trait]
1481    impl<
1482        P: FnMut(
1483                u64,
1484                Vec<SinkMetadata>,
1485                Option<PbSinkSchemaChange>,
1486            ) -> Result<Option<Vec<u8>>, SinkError>
1487            + Send
1488            + 'static,
1489        CD: FnMut(u64, Vec<u8>) -> Result<(), SinkError> + Send + 'static,
1490        CS: FnMut(u64, PbSinkSchemaChange) -> Result<(), SinkError> + Send + 'static,
1491    > TwoPhaseCommitCoordinator for MockTwoPhaseCoordinator<P, CD, CS>
1492    {
1493        async fn init(&mut self) -> risingwave_connector::sink::Result<()> {
1494            Ok(())
1495        }
1496
1497        async fn pre_commit(
1498            &mut self,
1499            epoch: u64,
1500            metadata: Vec<SinkMetadata>,
1501            schema_change: Option<PbSinkSchemaChange>,
1502        ) -> risingwave_connector::sink::Result<Option<Vec<u8>>> {
1503            (self.pre_commit)(epoch, metadata, schema_change)
1504        }
1505
1506        async fn commit_data(
1507            &mut self,
1508            epoch: u64,
1509            commit_metadata: Vec<u8>,
1510        ) -> risingwave_connector::sink::Result<()> {
1511            (self.commit_data)(epoch, commit_metadata)
1512        }
1513
1514        async fn commit_schema_change(
1515            &mut self,
1516            epoch: u64,
1517            schema_change: PbSinkSchemaChange,
1518        ) -> risingwave_connector::sink::Result<()> {
1519            (self.commit_schema_change)(epoch, schema_change)
1520        }
1521
1522        async fn abort(&mut self, _epoch: u64, _commit_metadata: Vec<u8>) {
1523            tracing::debug!("abort called");
1524        }
1525    }
1526
1527    async fn prepare_db_backend() -> DatabaseConnection {
1528        let db: DatabaseConnection = Database::connect("sqlite::memory:").await.unwrap();
1529        let ddl = "
1530            CREATE TABLE IF NOT EXISTS pending_sink_state (
1531                sink_id i32 NOT NULL,
1532                epoch i64 NOT NULL,
1533                sink_state STRING NOT NULL,
1534                metadata BLOB NOT NULL,
1535                schema_change BLOB,
1536                PRIMARY KEY (sink_id, epoch)
1537            )
1538        ";
1539        db.execute(sea_orm::Statement::from_string(
1540            db.get_database_backend(),
1541            ddl.to_owned(),
1542        ))
1543        .await
1544        .unwrap();
1545        db
1546    }
1547
1548    async fn list_rows(
1549        db: &DatabaseConnection,
1550    ) -> Vec<(i32, i64, String, Vec<u8>, Option<PbSinkSchemaChange>)> {
1551        let sql =
1552            "SELECT sink_id, epoch, sink_state, metadata, schema_change FROM pending_sink_state";
1553        let rows = db
1554            .query_all(sea_orm::Statement::from_string(
1555                db.get_database_backend(),
1556                sql.to_owned(),
1557            ))
1558            .await
1559            .unwrap();
1560        rows.into_iter()
1561            .map(|row| {
1562                (
1563                    row.try_get("", "sink_id").unwrap(),
1564                    row.try_get("", "epoch").unwrap(),
1565                    row.try_get("", "sink_state").unwrap(),
1566                    row.try_get("", "metadata").unwrap(),
1567                    row.try_get::<Option<SinkSchemachange>>("", "schema_change")
1568                        .unwrap()
1569                        .map(|v| v.to_protobuf()),
1570                )
1571            })
1572            .collect()
1573    }
1574
1575    async fn set_epoch_aborted(db: &DatabaseConnection, sink_id: SinkId, epoch: u64) {
1576        let sql = format!(
1577            "UPDATE pending_sink_state SET sink_state = 'ABORTED' WHERE sink_id = {} AND epoch = {}",
1578            sink_id, epoch as i64
1579        );
1580        db.execute(sea_orm::Statement::from_string(
1581            db.get_database_backend(),
1582            sql,
1583        ))
1584        .await
1585        .unwrap();
1586    }
1587
1588    #[tokio::test]
1589    async fn test_init_response_skips_recovered_pending_epoch() {
1590        let db = prepare_db_backend().await;
1591
1592        let param = SinkParam {
1593            sink_id: SinkId::from(1),
1594            sink_name: "test".into(),
1595            properties: Default::default(),
1596            columns: vec![],
1597            downstream_pk: None,
1598            sink_type: SinkType::AppendOnly,
1599            ignore_delete: false,
1600            format_desc: None,
1601            db_name: "test".into(),
1602            sink_from_name: "test".into(),
1603        };
1604
1605        let epoch1 = 233;
1606        let metadata = vec![1u8, 2u8];
1607
1608        let sql = format!(
1609            "INSERT INTO pending_sink_state (sink_id, epoch, sink_state, metadata) VALUES ({}, {}, 'PENDING', x'0102')",
1610            param.sink_id, epoch1
1611        );
1612        db.execute(sea_orm::Statement::from_string(
1613            db.get_database_backend(),
1614            sql,
1615        ))
1616        .await
1617        .unwrap();
1618
1619        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1620        let build_bitmap = |indexes: &[usize]| {
1621            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1622            for i in indexes {
1623                builder.set(*i, true);
1624            }
1625            builder.finish()
1626        };
1627        let vnode = build_bitmap(&all_vnode);
1628
1629        let sender = Arc::new(tokio::sync::Mutex::new(None));
1630        let mock_subscriber: SinkCommittedEpochSubscriber = {
1631            let captured_sender = sender.clone();
1632            Arc::new(move |_sink_id: SinkId| {
1633                let (epoch_sender, receiver) = unbounded_channel();
1634                let captured_sender = captured_sender.clone();
1635                async move {
1636                    let mut guard = captured_sender.lock().await;
1637                    *guard = Some(epoch_sender);
1638                    Ok((epoch1, receiver))
1639                }
1640                .boxed()
1641            })
1642        };
1643
1644        let pre_commit_attempt = Arc::new(AtomicI32::new(0));
1645        let commit_attempt = Arc::new(AtomicI32::new(0));
1646        let (manager, (_join_handle, _stop_tx)) =
1647            SinkCoordinatorManager::start_worker_with_spawn_worker({
1648                let expected_param = param.clone();
1649                let db = db.clone();
1650                let metadata = metadata.clone();
1651                let pre_commit_attempt = pre_commit_attempt.clone();
1652                let commit_attempt = commit_attempt.clone();
1653                move |param, new_writer_rx| {
1654                    let expected_param = expected_param.clone();
1655                    let db = db.clone();
1656                    let metadata = metadata.clone();
1657                    let pre_commit_attempt = pre_commit_attempt.clone();
1658                    let commit_attempt = commit_attempt.clone();
1659                    tokio::spawn({
1660                        let subscriber = mock_subscriber.clone();
1661                        async move {
1662                            assert_eq!(param, expected_param);
1663                            CoordinatorWorker::execute_coordinator(
1664                                db,
1665                                param.clone(),
1666                                new_writer_rx,
1667                                MockTwoPhaseCoordinator::new_coordinator(
1668                                    move |_epoch, _metadata_list, _schema_change| {
1669                                        pre_commit_attempt.fetch_add(
1670                                            1,
1671                                            std::sync::atomic::Ordering::SeqCst,
1672                                        );
1673                                        Err(SinkError::Coordinator(anyhow!(
1674                                            "pre_commit should not be called for a known pending epoch"
1675                                        )))
1676                                    },
1677                                    move |epoch, commit_metadata| {
1678                                        assert_eq!(epoch, epoch1);
1679                                        assert_eq!(commit_metadata, metadata);
1680                                        commit_attempt
1681                                            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1682                                        Ok(())
1683                                    },
1684                                    move |_epoch, _schema_change| unreachable!(),
1685                                ),
1686                                subscriber.clone(),
1687                            )
1688                            .await;
1689                        }
1690                    })
1691                }
1692            });
1693
1694        let (_client, log_store_rewind_start_epoch) =
1695            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1696                Ok(tonic::Response::new(
1697                    manager
1698                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1699                        .await
1700                        .unwrap()
1701                        .boxed(),
1702                ))
1703            })
1704            .await
1705            .unwrap();
1706        assert_eq!(log_store_rewind_start_epoch, Some(epoch1));
1707
1708        for _ in 0..50 {
1709            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1710            let rows = list_rows(&db).await;
1711            if rows[0].2 == "COMMITTED" {
1712                break;
1713            }
1714        }
1715
1716        assert_eq!(
1717            pre_commit_attempt.load(std::sync::atomic::Ordering::SeqCst),
1718            0
1719        );
1720        assert_eq!(commit_attempt.load(std::sync::atomic::Ordering::SeqCst), 1);
1721        let rows = list_rows(&db).await;
1722        assert_eq!(rows.len(), 1);
1723        assert_eq!(rows[0].1, epoch1 as i64);
1724        assert_eq!(rows[0].2, "COMMITTED");
1725    }
1726
1727    #[tokio::test]
1728    async fn test_pre_commit_failed() {
1729        let db = prepare_db_backend().await;
1730
1731        let param = SinkParam {
1732            sink_id: SinkId::from(1),
1733            sink_name: "test".into(),
1734            properties: Default::default(),
1735            columns: vec![],
1736            downstream_pk: None,
1737            sink_type: SinkType::AppendOnly,
1738            ignore_delete: false,
1739            format_desc: None,
1740            db_name: "test".into(),
1741            sink_from_name: "test".into(),
1742        };
1743
1744        let epoch1 = 233;
1745
1746        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1747        let build_bitmap = |indexes: &[usize]| {
1748            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1749            for i in indexes {
1750                builder.set(*i, true);
1751            }
1752            builder.finish()
1753        };
1754        let vnode = build_bitmap(&all_vnode);
1755
1756        let metadata = vec![1u8, 2u8];
1757        let sender = Arc::new(tokio::sync::Mutex::new(None));
1758        let mock_subscriber: SinkCommittedEpochSubscriber = {
1759            let captured_sender = sender.clone();
1760            Arc::new(move |_sink_id: SinkId| {
1761                let (sender, receiver) = unbounded_channel();
1762                let captured_sender = captured_sender.clone();
1763                async move {
1764                    let mut guard = captured_sender.lock().await;
1765                    *guard = Some(sender);
1766                    Ok((epoch1, receiver))
1767                }
1768                .boxed()
1769            })
1770        };
1771
1772        let (manager, (_join_handle, _stop_tx)) =
1773            SinkCoordinatorManager::start_worker_with_spawn_worker({
1774                let expected_param = param.clone();
1775                let db = db.clone();
1776                move |param, new_writer_rx| {
1777                    let expected_param = expected_param.clone();
1778                    let db = db.clone();
1779                    tokio::spawn({
1780                        let subscriber = mock_subscriber.clone();
1781                        async move {
1782                            // validate the start request
1783                            assert_eq!(param, expected_param);
1784                            CoordinatorWorker::execute_coordinator(
1785                                db,
1786                                param.clone(),
1787                                new_writer_rx,
1788                                MockTwoPhaseCoordinator::new_coordinator(
1789                                    move |_epoch, _metadata_list, _schema_change| {
1790                                        Err(SinkError::Coordinator(anyhow!("failed to pre commit")))
1791                                    },
1792                                    move |_epoch, _commit_metadata| unreachable!(),
1793                                    move |_epoch, _schema_change| unreachable!(),
1794                                ),
1795                                subscriber.clone(),
1796                            )
1797                            .await;
1798                        }
1799                    })
1800                }
1801            });
1802
1803        let build_client = |vnode| async {
1804            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1805                Ok(tonic::Response::new(
1806                    manager
1807                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1808                        .await
1809                        .unwrap()
1810                        .boxed(),
1811                ))
1812            })
1813            .await
1814            .unwrap()
1815            .0
1816        };
1817
1818        let mut client = build_client(vnode).await;
1819
1820        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
1821        assert_eq!(aligned_epoch, 1);
1822
1823        let commit_result = client
1824            .commit(
1825                epoch1,
1826                SinkMetadata {
1827                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1828                        metadata: metadata.clone(),
1829                    })),
1830                },
1831                None,
1832            )
1833            .await;
1834        assert!(commit_result.is_err());
1835
1836        let rows = list_rows(&db).await;
1837        assert!(rows.is_empty());
1838    }
1839
1840    #[tokio::test]
1841    async fn test_waiting_on_checkpoint() {
1842        let db = prepare_db_backend().await;
1843
1844        let param = SinkParam {
1845            sink_id: SinkId::from(1),
1846            sink_name: "test".into(),
1847            properties: Default::default(),
1848            columns: vec![],
1849            downstream_pk: None,
1850            sink_type: SinkType::AppendOnly,
1851            ignore_delete: false,
1852            format_desc: None,
1853            db_name: "test".into(),
1854            sink_from_name: "test".into(),
1855        };
1856
1857        let epoch0 = 232;
1858        let epoch1 = 233;
1859
1860        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
1861        let build_bitmap = |indexes: &[usize]| {
1862            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
1863            for i in indexes {
1864                builder.set(*i, true);
1865            }
1866            builder.finish()
1867        };
1868        let vnode = build_bitmap(&all_vnode);
1869
1870        let metadata = vec![1u8, 2u8];
1871
1872        let sender = Arc::new(tokio::sync::Mutex::new(None));
1873        let mock_subscriber: SinkCommittedEpochSubscriber = {
1874            let captured_sender = sender.clone();
1875            Arc::new(move |_sink_id: SinkId| {
1876                let (sender, receiver) = unbounded_channel();
1877                let captured_sender = captured_sender.clone();
1878                async move {
1879                    let mut guard = captured_sender.lock().await;
1880                    *guard = Some(sender);
1881                    Ok((epoch0, receiver))
1882                }
1883                .boxed()
1884            })
1885        };
1886
1887        let (manager, (_join_handle, _stop_tx)) =
1888            SinkCoordinatorManager::start_worker_with_spawn_worker({
1889                let expected_param = param.clone();
1890                let metadata = metadata.clone();
1891                let db = db.clone();
1892                move |param, new_writer_rx| {
1893                    let metadata = metadata.clone();
1894                    let expected_param = expected_param.clone();
1895                    let db = db.clone();
1896                    tokio::spawn({
1897                        let subscriber = mock_subscriber.clone();
1898                        async move {
1899                            // validate the start request
1900                            assert_eq!(param, expected_param);
1901                            CoordinatorWorker::execute_coordinator(
1902                                db,
1903                                param.clone(),
1904                                new_writer_rx,
1905                                MockTwoPhaseCoordinator::new_coordinator(
1906                                    move |_epoch, metadata_list, _schema_change| {
1907                                        let metadata =
1908                                            metadata_list.into_iter().exactly_one().unwrap();
1909                                        Ok(match metadata.metadata {
1910                                            Some(Metadata::Serialized(SerializedMetadata {
1911                                                metadata,
1912                                            })) => Some(metadata),
1913                                            _ => unreachable!(),
1914                                        })
1915                                    },
1916                                    move |_epoch, commit_metadata| {
1917                                        assert_eq!(commit_metadata, metadata);
1918                                        Ok(())
1919                                    },
1920                                    move |_epoch, _schema_change| unreachable!(),
1921                                ),
1922                                subscriber.clone(),
1923                            )
1924                            .await;
1925                        }
1926                    })
1927                }
1928            });
1929
1930        let build_client = |vnode| async {
1931            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
1932                Ok(tonic::Response::new(
1933                    manager
1934                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
1935                        .await
1936                        .unwrap()
1937                        .boxed(),
1938                ))
1939            })
1940            .await
1941            .unwrap()
1942            .0
1943        };
1944
1945        let mut client = build_client(vnode).await;
1946
1947        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
1948        assert_eq!(aligned_epoch, 1);
1949
1950        client
1951            .commit(
1952                epoch1,
1953                SinkMetadata {
1954                    metadata: Some(Metadata::Serialized(SerializedMetadata {
1955                        metadata: metadata.clone(),
1956                    })),
1957                },
1958                None,
1959            )
1960            .await
1961            .unwrap();
1962
1963        {
1964            let rows = list_rows(&db).await;
1965            assert_eq!(rows.len(), 1);
1966            assert_eq!(rows[0].1, epoch1 as i64);
1967            assert_eq!(rows[0].2, "PENDING");
1968
1969            let guard = sender.lock().await;
1970            let sender = guard.as_ref().unwrap().clone();
1971            sender.send(233).unwrap();
1972        }
1973
1974        // wait max 5 seconds for the commit to be processed
1975        for _ in 0..50 {
1976            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1977            let rows = list_rows(&db).await;
1978            if rows[0].2 == "COMMITTED" {
1979                break;
1980            }
1981        }
1982
1983        {
1984            let rows = list_rows(&db).await;
1985            assert_eq!(rows.len(), 1);
1986            assert_eq!(rows[0].1, epoch1 as i64);
1987            assert_eq!(rows[0].2, "COMMITTED");
1988        }
1989    }
1990
1991    #[tokio::test]
1992    async fn test_commit_retry_loop() {
1993        let db = prepare_db_backend().await;
1994
1995        let param = SinkParam {
1996            sink_id: SinkId::from(1),
1997            sink_name: "test".into(),
1998            properties: Default::default(),
1999            columns: vec![],
2000            downstream_pk: None,
2001            sink_type: SinkType::AppendOnly,
2002            ignore_delete: false,
2003            format_desc: None,
2004            db_name: "test".into(),
2005            sink_from_name: "test".into(),
2006        };
2007
2008        let epoch1 = 233;
2009
2010        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
2011        let build_bitmap = |indexes: &[usize]| {
2012            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
2013            for i in indexes {
2014                builder.set(*i, true);
2015            }
2016            builder.finish()
2017        };
2018        let vnode = build_bitmap(&all_vnode);
2019
2020        let metadata = vec![1u8, 2u8];
2021        let sender = Arc::new(tokio::sync::Mutex::new(None));
2022        let mock_subscriber: SinkCommittedEpochSubscriber = {
2023            let captured_sender = sender.clone();
2024            Arc::new(move |_sink_id: SinkId| {
2025                let (sender, receiver) = unbounded_channel();
2026                let captured_sender = captured_sender.clone();
2027                async move {
2028                    let mut guard = captured_sender.lock().await;
2029                    *guard = Some(sender);
2030                    Ok((epoch1, receiver))
2031                }
2032                .boxed()
2033            })
2034        };
2035
2036        let commit_attempt = Arc::new(AtomicI32::new(0));
2037
2038        let (manager, (_join_handle, _stop_tx)) =
2039            SinkCoordinatorManager::start_worker_with_spawn_worker({
2040                let expected_param = param.clone();
2041                let metadata = metadata.clone();
2042                let db = db.clone();
2043                let commit_attempt = commit_attempt.clone();
2044                move |param, new_writer_rx| {
2045                    let metadata = metadata.clone();
2046                    let expected_param = expected_param.clone();
2047                    let db = db.clone();
2048                    let commit_attempt = commit_attempt.clone();
2049                    tokio::spawn({
2050                        let subscriber = mock_subscriber.clone();
2051                        async move {
2052                            // validate the start request
2053                            assert_eq!(param, expected_param);
2054                            CoordinatorWorker::execute_coordinator(
2055                                db,
2056                                param.clone(),
2057                                new_writer_rx,
2058                                MockTwoPhaseCoordinator::new_coordinator(
2059                                    move |_epoch, metadata_list, _schema_change| {
2060                                        let metadata =
2061                                            metadata_list.into_iter().exactly_one().unwrap();
2062                                        Ok(match metadata.metadata {
2063                                            Some(Metadata::Serialized(SerializedMetadata {
2064                                                metadata,
2065                                            })) => Some(metadata),
2066                                            _ => unreachable!(),
2067                                        })
2068                                    },
2069                                    move |_epoch, commit_metadata| {
2070                                        assert_eq!(commit_metadata, metadata);
2071                                        if commit_attempt
2072                                            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
2073                                            < 2
2074                                        {
2075                                            Err(SinkError::Coordinator(anyhow!("failed to commit")))
2076                                        } else {
2077                                            Ok(())
2078                                        }
2079                                    },
2080                                    move |_epoch, _schema_change| unreachable!(),
2081                                ),
2082                                subscriber.clone(),
2083                            )
2084                            .await;
2085                        }
2086                    })
2087                }
2088            });
2089
2090        let build_client = |vnode| async {
2091            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
2092                Ok(tonic::Response::new(
2093                    manager
2094                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
2095                        .await
2096                        .unwrap()
2097                        .boxed(),
2098                ))
2099            })
2100            .await
2101            .unwrap()
2102            .0
2103        };
2104
2105        let mut client = build_client(vnode).await;
2106
2107        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
2108        assert_eq!(aligned_epoch, 1);
2109
2110        client
2111            .commit(
2112                epoch1,
2113                SinkMetadata {
2114                    metadata: Some(Metadata::Serialized(SerializedMetadata {
2115                        metadata: metadata.clone(),
2116                    })),
2117                },
2118                None,
2119            )
2120            .await
2121            .unwrap();
2122
2123        // wait max 10 seconds for the commit to be processed
2124        for _ in 0..100 {
2125            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2126            let rows = list_rows(&db).await;
2127            if rows[0].2 == "COMMITTED" {
2128                break;
2129            }
2130        }
2131
2132        assert_eq!(commit_attempt.load(std::sync::atomic::Ordering::SeqCst), 3);
2133
2134        {
2135            let rows = list_rows(&db).await;
2136            assert_eq!(rows.len(), 1);
2137            assert_eq!(rows[0].1, epoch1 as i64);
2138            assert_eq!(rows[0].2, "COMMITTED");
2139        }
2140    }
2141
2142    #[tokio::test]
2143    async fn test_aborted() {
2144        let db = prepare_db_backend().await;
2145
2146        let param = SinkParam {
2147            sink_id: SinkId::from(1),
2148            sink_name: "test".into(),
2149            properties: Default::default(),
2150            columns: vec![],
2151            downstream_pk: None,
2152            sink_type: SinkType::AppendOnly,
2153            ignore_delete: false,
2154            format_desc: None,
2155            db_name: "test".into(),
2156            sink_from_name: "test".into(),
2157        };
2158
2159        let epoch0 = 232;
2160        let epoch1 = 233;
2161
2162        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
2163        let build_bitmap = |indexes: &[usize]| {
2164            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
2165            for i in indexes {
2166                builder.set(*i, true);
2167            }
2168            builder.finish()
2169        };
2170        let vnode = build_bitmap(&all_vnode);
2171
2172        let metadata = vec![1u8, 2u8];
2173
2174        let sender = Arc::new(tokio::sync::Mutex::new(None));
2175        let mock_subscriber: SinkCommittedEpochSubscriber = {
2176            let captured_sender = sender.clone();
2177            Arc::new(move |_sink_id: SinkId| {
2178                let (sender, receiver) = unbounded_channel();
2179                let captured_sender = captured_sender.clone();
2180                async move {
2181                    let mut guard = captured_sender.lock().await;
2182                    *guard = Some(sender);
2183                    Ok((epoch0, receiver))
2184                }
2185                .boxed()
2186            })
2187        };
2188
2189        let (manager, (_join_handle, _stop_tx)) =
2190            SinkCoordinatorManager::start_worker_with_spawn_worker({
2191                let expected_param = param.clone();
2192                let metadata = metadata.clone();
2193                let db = db.clone();
2194                move |param, new_writer_rx| {
2195                    let metadata = metadata.clone();
2196                    let expected_param = expected_param.clone();
2197                    let db = db.clone();
2198                    tokio::spawn({
2199                        let subscriber = mock_subscriber.clone();
2200                        async move {
2201                            // validate the start request
2202                            assert_eq!(param, expected_param);
2203                            CoordinatorWorker::execute_coordinator(
2204                                db,
2205                                param.clone(),
2206                                new_writer_rx,
2207                                MockTwoPhaseCoordinator::new_coordinator(
2208                                    move |_epoch, metadata_list, _schema_change| {
2209                                        let metadata =
2210                                            metadata_list.into_iter().exactly_one().unwrap();
2211                                        Ok(match metadata.metadata {
2212                                            Some(Metadata::Serialized(SerializedMetadata {
2213                                                metadata,
2214                                            })) => Some(metadata),
2215                                            _ => unreachable!(),
2216                                        })
2217                                    },
2218                                    move |_epoch, commit_metadata| {
2219                                        assert_eq!(commit_metadata, metadata);
2220                                        Ok(())
2221                                    },
2222                                    move |_epoch, _schema_change| unreachable!(),
2223                                ),
2224                                subscriber.clone(),
2225                            )
2226                            .await;
2227                        }
2228                    })
2229                }
2230            });
2231
2232        let build_client = |vnode| async {
2233            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
2234                Ok(tonic::Response::new(
2235                    manager
2236                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
2237                        .await
2238                        .unwrap()
2239                        .boxed(),
2240                ))
2241            })
2242            .await
2243            .unwrap()
2244            .0
2245        };
2246
2247        let mut client = build_client(vnode.clone()).await;
2248
2249        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
2250        assert_eq!(aligned_epoch, 1);
2251
2252        client
2253            .commit(
2254                epoch1,
2255                SinkMetadata {
2256                    metadata: Some(Metadata::Serialized(SerializedMetadata {
2257                        metadata: metadata.clone(),
2258                    })),
2259                },
2260                None,
2261            )
2262            .await
2263            .unwrap();
2264
2265        manager.stop_sink_coordinator(vec![SinkId::from(1)]).await;
2266
2267        {
2268            let rows = list_rows(&db).await;
2269            assert_eq!(rows.len(), 1);
2270            assert_eq!(rows[0].1, epoch1 as i64);
2271            assert_eq!(rows[0].2, "PENDING");
2272
2273            set_epoch_aborted(&db, SinkId::from(1), epoch1).await;
2274            let rows = list_rows(&db).await;
2275            assert_eq!(rows.len(), 1);
2276            assert_eq!(rows[0].1, epoch1 as i64);
2277            assert_eq!(rows[0].2, "ABORTED");
2278        }
2279
2280        let mut client = build_client(vnode).await;
2281
2282        let aligned_epoch = client.align_initial_epoch(1).await.unwrap();
2283        assert_eq!(aligned_epoch, 1);
2284
2285        {
2286            let rows = list_rows(&db).await;
2287            assert!(rows.is_empty());
2288        }
2289    }
2290
2291    #[tokio::test]
        // Scenario exercised below: a first writer commits `epoch1` (with a schema change) and the
        // row is observed as "PENDING"; a second writer covering the same vnodes then tries to
        // connect. Its connection future must stay pending, both before and after the first
        // writer stops, until `epoch1` is reported through the mocked committed-epoch subscriber.
        // After that the second connection resolves with `Some(epoch1)` and the row is "COMMITTED".
        // NOTE(review): the test name suggests this models a reschedule of the sink writers;
        // confirm against the coordinator worker implementation.
2292    async fn test_flush_when_reschedule() {
2293        let db = prepare_db_backend().await;
2294
2295        let param = SinkParam {
2296            sink_id: SinkId::from(1),
2297            sink_name: "test".into(),
2298            properties: Default::default(),
2299            columns: vec![],
2300            downstream_pk: None,
2301            sink_type: SinkType::AppendOnly,
2302            ignore_delete: false,
2303            format_desc: None,
2304            db_name: "test".into(),
2305            sink_from_name: "test".into(),
2306        };
2307
            // `epoch0` is the initial value handed out by the mock subscriber below;
            // `epoch1` is the epoch committed by the first client and later sent through the
            // subscriber's channel.
2308        let epoch0 = 232;
2309        let epoch1 = 233;
2310
            // Build a bitmap with every index in `0..VirtualNode::COUNT_FOR_TEST` set, so each
            // client created from it claims all test vnodes.
2311        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
2312        let build_bitmap = |indexes: &[usize]| {
2313            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
2314            for i in indexes {
2315                builder.set(*i, true);
2316            }
2317            builder.finish()
2318        };
2319        let vnode = build_bitmap(&all_vnode);
2320
2321        let metadata = vec![1u8, 2u8];
            // Schema change attached to the commit: the original schema has one Int32 column
            // `col_v1`, and the operation adds one Varchar column `new_col`.
2322        let schema_change = PbSinkSchemaChange {
2323            original_schema: vec![PbField {
2324                data_type: Some(PbDataType {
2325                    type_name: PbTypeName::Int32 as i32,
2326                    ..Default::default()
2327                }),
2328                name: "col_v1".into(),
2329            }],
2330            op: Some(SinkSchemachangeOp::AddColumns(PbSinkAddColumnsOp {
2331                fields: vec![PbField {
2332                    data_type: Some(PbDataType {
2333                        type_name: PbTypeName::Varchar as i32,
2334                        ..Default::default()
2335                    }),
2336                    name: "new_col".into(),
2337                }],
2338            })),
2339        };
2340
            // Shared slot that captures the sending half of the channel created inside the mock
            // subscriber, so the test body can later push an epoch into it.
2341        let sender = Arc::new(tokio::sync::Mutex::new(None));
            // Mock subscriber: on each call it creates a fresh unbounded channel, stores the
            // sender in the shared slot (overwriting any previous one), and returns
            // `(epoch0, receiver)`.
2342        let mock_subscriber: SinkCommittedEpochSubscriber = {
2343            let captured_sender = sender.clone();
2344            Arc::new(move |_sink_id: SinkId| {
2345                let (sender, receiver) = unbounded_channel();
2346                let captured_sender = captured_sender.clone();
2347                async move {
2348                    let mut guard = captured_sender.lock().await;
2349                    *guard = Some(sender);
2350                    Ok((epoch0, receiver))
2351                }
2352                .boxed()
2353            })
2354        };
2355
            // Start the manager with a custom worker-spawning closure. Each spawned worker
            // asserts it received the expected `SinkParam` and runs
            // `CoordinatorWorker::execute_coordinator` with a mock two-phase coordinator and the
            // mock subscriber defined above.
2356        let (manager, (_join_handle, _stop_tx)) =
2357            SinkCoordinatorManager::start_worker_with_spawn_worker({
2358                let expected_param = param.clone();
2359                let metadata = metadata.clone();
2360                let schema_change = schema_change.clone();
2361                let db = db.clone();
2362                move |param, new_writer_rx| {
2363                    let metadata = metadata.clone();
2364                    let schema_change_for_pre_commit = schema_change.clone();
2365                    let schema_change_for_commit = schema_change.clone();
2366                    let expected_param = expected_param.clone();
2367                    let db = db.clone();
2368                    tokio::spawn({
2369                        let subscriber = mock_subscriber.clone();
2370                        async move {
2371                            assert_eq!(param, expected_param);
2372                            CoordinatorWorker::execute_coordinator(
2373                                db,
2374                                param.clone(),
2375                                new_writer_rx,
2376                                MockTwoPhaseCoordinator::new_coordinator(
                                        // First callback (presumably pre-commit, per the captured
                                        // variable name): checks the schema change is passed
                                        // through, requires exactly one metadata entry, and
                                        // returns its serialized bytes.
2377                                    move |_epoch, metadata_list, schema_change| {
2378                                        assert_eq!(
2379                                            schema_change,
2380                                            Some(schema_change_for_pre_commit.clone())
2381                                        );
2382                                        let metadata =
2383                                            metadata_list.into_iter().exactly_one().unwrap();
2384                                        Ok(match metadata.metadata {
2385                                            Some(Metadata::Serialized(SerializedMetadata {
2386                                                metadata,
2387                                            })) => Some(metadata),
2388                                            _ => unreachable!(),
2389                                        })
2390                                    },
                                        // Second callback (presumably commit): the metadata it
                                        // receives must equal the bytes sent by the client.
2391                                    move |_epoch, commit_metadata| {
2392                                        assert_eq!(commit_metadata, metadata);
2393                                        Ok(())
2394                                    },
                                        // Third callback (presumably schema-change commit): the
                                        // schema change must equal the one sent by the client.
2395                                    move |_epoch, schema_change| {
2396                                        assert_eq!(schema_change, schema_change_for_commit.clone());
2397                                        Ok(())
2398                                    },
2399                                ),
2400                                subscriber.clone(),
2401                            )
2402                            .await;
2403                        }
2404                    })
2405                }
2406            });
2407
            // Helper that opens a coordinator stream for the given vnode bitmap by routing the
            // request stream through `manager.handle_new_request`. The result is left as a
            // `Result` here; callers below unwrap it into a `(handle, initial epoch)` pair.
2408        let build_client = |vnode| async {
2409            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
2410                Ok(tonic::Response::new(
2411                    manager
2412                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
2413                        .await
2414                        .unwrap()
2415                        .boxed(),
2416                ))
2417            })
2418            .await
2419        };
2420
2421        let (mut client1, _) = build_client(vnode.clone()).await.unwrap();
2422
2423        let aligned_epoch = client1.align_initial_epoch(1).await.unwrap();
2424        assert_eq!(aligned_epoch, 1);
2425
            // Commit `epoch1` from the first client, carrying both metadata and the schema change.
2426        client1
2427            .commit(
2428                epoch1,
2429                SinkMetadata {
2430                    metadata: Some(Metadata::Serialized(SerializedMetadata {
2431                        metadata: metadata.clone(),
2432                    })),
2433                },
2434                Some(schema_change.clone()),
2435            )
2436            .await
2437            .unwrap();
2438
            // After the commit call returns, exactly one row is expected: epoch `epoch1`, state
            // "PENDING", with the schema change stored alongside it.
2439        {
2440            let rows = list_rows(&db).await;
2441            assert_eq!(rows.len(), 1);
2442            assert_eq!(rows[0].1, epoch1 as i64);
2443            assert_eq!(rows[0].2, "PENDING");
2444            assert_eq!(rows[0].4, Some(schema_change.clone()));
2445        }
2446
            // Start connecting a second client over the same vnodes. The future is pinned and
            // polled exactly once: wrapping the inner poll result in `Poll::Ready` makes
            // `poll_fn` return immediately, so the test can assert the connection has not
            // completed yet without blocking on it.
2447        let mut build_client2_future = pin!(build_client(vnode.clone()));
2448        assert!(
2449            poll_fn(|cx| Poll::Ready(build_client2_future.as_mut().poll(cx)))
2450                .await
2451                .is_pending()
2452        );
2453
2454        client1.stop().await.unwrap();
2455
            // Stopping the first client alone must not let the second connection complete.
2456        assert!(
2457            poll_fn(|cx| Poll::Ready(build_client2_future.as_mut().poll(cx)))
2458                .await
2459                .is_pending()
2460        );
2461
            // Report `epoch1` through the sender captured by the mock subscriber. The guard is
            // scoped so the lock is released before awaiting the second client.
2462        {
2463            let guard = sender.lock().await;
2464            let sender = guard.as_ref().unwrap().clone();
2465            sender.send(epoch1).unwrap();
2466        }
2467
            // Now the second connection is expected to resolve, reporting `epoch1` as its
            // initial epoch.
2468        let (_, init_epoch) = build_client2_future.await.unwrap();
2469        assert_eq!(init_epoch, Some(epoch1));
2470
            // The same single row should now be in state "COMMITTED", schema change retained.
2471        {
2472            let rows = list_rows(&db).await;
2473            assert_eq!(rows.len(), 1);
2474            assert_eq!(rows[0].1, epoch1 as i64);
2475            assert_eq!(rows[0].2, "COMMITTED");
2476            assert_eq!(rows[0].4, Some(schema_change.clone()));
2477        }
2478    }
2479}