risingwave_meta/manager/sink_coordination/manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::pin::pin;
use std::sync::Arc;

use anyhow::anyhow;
use futures::future::{BoxFuture, Either, select};
use futures::stream::FuturesUnordered;
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use risingwave_common::bitmap::Bitmap;
use risingwave_connector::sink::catalog::SinkId;
use risingwave_connector::sink::{SinkCommittedEpochSubscriber, SinkError, SinkParam};
use risingwave_pb::connector_service::coordinate_request::Msg;
use risingwave_pb::connector_service::{CoordinateRequest, CoordinateResponse, coordinate_request};
use rw_futures_util::pending_on_none;
use sea_orm::DatabaseConnection;
use thiserror_ext::AsReport;
use tokio::sync::mpsc;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tokio::sync::oneshot::{Receiver, Sender, channel};
use tokio::task::{JoinError, JoinHandle};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Status;
use tracing::{debug, error, info, warn};

use crate::hummock::HummockManagerRef;
use crate::manager::MetadataManager;
use crate::manager::sink_coordination::SinkWriterRequestStream;
use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
use crate::manager::sink_coordination::handle::SinkWriterCoordinationHandle;

macro_rules! send_with_err_check {
    ($tx:expr, $msg:expr) => {
        if $tx.send($msg).is_err() {
            error!("unable to send msg");
        }
    };
}

macro_rules! send_await_with_err_check {
    ($tx:expr, $msg:expr) => {
        if $tx.send($msg).await.is_err() {
            error!("unable to send msg");
        }
    };
}

const BOUNDED_CHANNEL_SIZE: usize = 16;

enum ManagerRequest {
    NewSinkWriter(SinkWriterCoordinationHandle),
    StopCoordinator {
        finish_notifier: Sender<()>,
        /// Sink id to stop. When `None`, stop all sink coordinators.
        sink_id: Option<SinkId>,
    },
}

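/// Cloneable handle to the sink coordination worker. All clones share the same underlying
/// `ManagerWorker` through `request_tx`.
///
/// A minimal usage sketch (assuming a meta-node context that provides `db`, `hummock_manager`
/// and `metadata_manager`; not a prescribed setup):
///
/// ```ignore
/// let (manager, (join_handle, shutdown_tx)) =
///     SinkCoordinatorManager::start_worker(db, hummock_manager, metadata_manager);
/// // Serve writer streams via `manager.handle_new_request(..)` from the coordination RPC,
/// // and call `manager.reset().await` when all coordinators need to be stopped.
/// ```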
#[derive(Clone)]
pub struct SinkCoordinatorManager {
    request_tx: mpsc::Sender<ManagerRequest>,
}

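/// Builds a [`SinkCommittedEpochSubscriber`] that resolves the sink's first state table and
/// subscribes to that table's committed epochs through the hummock manager.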
fn new_committed_epoch_subscriber(
    hummock_manager: HummockManagerRef,
    metadata_manager: MetadataManager,
) -> SinkCommittedEpochSubscriber {
    Arc::new(move |sink_id| {
        let hummock_manager = hummock_manager.clone();
        let metadata_manager = metadata_manager.clone();
        async move {
            let state_table_ids = metadata_manager
                .get_sink_state_table_ids(sink_id.sink_id as _)
                .await
                .map_err(SinkError::from)?;
            let Some(table_id) = state_table_ids.first() else {
                return Err(anyhow!("no state table id in sink: {}", sink_id).into());
            };
            hummock_manager
                .subscribe_table_committed_epoch(*table_id)
                .await
                .map_err(SinkError::from)
        }
        .boxed()
    })
}

impl SinkCoordinatorManager {
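    /// Spawns the manager worker task and returns the manager handle together with the
    /// worker's join handle and a shutdown sender.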
    pub fn start_worker(
        db: DatabaseConnection,
        hummock_manager: HummockManagerRef,
        metadata_manager: MetadataManager,
    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
        let subscriber =
            new_committed_epoch_subscriber(hummock_manager.clone(), metadata_manager.clone());
        Self::start_worker_with_spawn_worker(move |param, manager_request_stream| {
            tokio::spawn(CoordinatorWorker::run(
                param,
                manager_request_stream,
                db.clone(),
                subscriber.clone(),
            ))
        })
    }

    fn start_worker_with_spawn_worker(
        spawn_coordinator_worker: impl SpawnCoordinatorFn,
    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
        let (request_tx, request_rx) = mpsc::channel(BOUNDED_CHANNEL_SIZE);
        let (shutdown_tx, shutdown_rx) = channel();
        let worker = ManagerWorker::new(request_rx, shutdown_rx);
        let join_handle = tokio::spawn(worker.execute(spawn_coordinator_worker));
        (
            SinkCoordinatorManager { request_tx },
            (join_handle, shutdown_tx),
        )
    }

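    /// Handles a new bidirectional coordination stream from a sink writer. The first message
    /// must be a `StartRequest` carrying the sink param and vnode bitmap; the returned stream
    /// yields the coordinator's responses to this writer.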
    pub async fn handle_new_request(
        &self,
        mut request_stream: SinkWriterRequestStream,
    ) -> Result<impl Stream<Item = Result<CoordinateResponse, Status>> + use<>, Status> {
        let (param, vnode_bitmap) = match request_stream.try_next().await? {
            Some(CoordinateRequest {
                msg:
                    Some(Msg::StartRequest(coordinate_request::StartCoordinationRequest {
                        param: Some(param),
                        vnode_bitmap: Some(vnode_bitmap),
                    })),
            }) => (SinkParam::from_proto(param), Bitmap::from(&vnode_bitmap)),
            msg => {
                return Err(Status::invalid_argument(format!(
                    "expected CoordinateRequest::StartRequest in the first request, got {:?}",
                    msg
                )));
            }
        };
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        self.request_tx
            .send(ManagerRequest::NewSinkWriter(
                SinkWriterCoordinationHandle::new(request_stream, response_tx, param, vnode_bitmap),
            ))
            .await
            .map_err(|_| {
                Status::unavailable(
                    "unable to send to sink manager worker. The worker may have stopped",
                )
            })?;

        Ok(UnboundedReceiverStream::new(response_rx))
    }

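    /// Stops the coordinator of the given sink, or all coordinators when `sink_id` is `None`,
    /// and waits until the corresponding worker(s) have finished.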
    async fn stop_coordinator(&self, sink_id: Option<SinkId>) {
        let (tx, rx) = channel();
        send_await_with_err_check!(
            self.request_tx,
            ManagerRequest::StopCoordinator {
                finish_notifier: tx,
                sink_id,
            }
        );
        if rx.await.is_err() {
            error!("failed to wait for the sink manager worker to reset");
        }
        info!("successfully stopped coordinator: {:?}", sink_id);
    }

    pub async fn reset(&self) {
        self.stop_coordinator(None).await;
    }

    pub async fn stop_sink_coordinator(&self, sink_id: SinkId) {
        self.stop_coordinator(Some(sink_id)).await;
    }
}

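/// Per-sink bookkeeping held by the manager worker for a running coordinator worker.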
struct CoordinatorWorkerHandle {
    /// Sender to the coordinator worker. Dropping the sender serves as the stop signal.
    request_sender: Option<UnboundedSender<SinkWriterCoordinationHandle>>,
    /// Senders notified when the coordinator worker stops.
    finish_notifiers: Vec<Sender<()>>,
}

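/// Event loop that owns the per-sink coordinator workers: it routes new writer handles to them,
/// reaps finished workers, and tears everything down on shutdown.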
struct ManagerWorker {
    request_rx: mpsc::Receiver<ManagerRequest>,
    // Polled by mutable reference in `next_event`, so it only needs `&mut ManagerWorker`.
    shutdown_rx: Receiver<()>,

    running_coordinator_worker_join_handles:
        FuturesUnordered<BoxFuture<'static, (SinkId, Result<(), JoinError>)>>,
    running_coordinator_worker: HashMap<SinkId, CoordinatorWorkerHandle>,
}

enum ManagerEvent {
    NewRequest(ManagerRequest),
    CoordinatorWorkerFinished {
        sink_id: SinkId,
        join_result: Result<(), JoinError>,
    },
}

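/// Factory that spawns a coordinator worker task for one sink, given its param and a receiver
/// of incoming writer handles.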
trait SpawnCoordinatorFn = FnMut(SinkParam, UnboundedReceiver<SinkWriterCoordinationHandle>) -> JoinHandle<()>
    + Send
    + 'static;

impl ManagerWorker {
    fn new(request_rx: mpsc::Receiver<ManagerRequest>, shutdown_rx: Receiver<()>) -> Self {
        ManagerWorker {
            request_rx,
            shutdown_rx,
            running_coordinator_worker_join_handles: Default::default(),
            running_coordinator_worker: Default::default(),
        }
    }

    async fn execute(mut self, mut spawn_coordinator_worker: impl SpawnCoordinatorFn) {
        while let Some(event) = self.next_event().await {
            match event {
                ManagerEvent::NewRequest(request) => match request {
                    ManagerRequest::NewSinkWriter(request) => {
                        self.handle_new_sink_writer(request, &mut spawn_coordinator_worker)
                    }
                    ManagerRequest::StopCoordinator {
                        finish_notifier,
                        sink_id,
                    } => {
                        if let Some(sink_id) = sink_id {
                            if let Some(worker_handle) =
                                self.running_coordinator_worker.get_mut(&sink_id)
                            {
                                if let Some(sender) = worker_handle.request_sender.take() {
                                    // drop the sender as a signal to notify the coordinator worker
                                    // to stop
                                    drop(sender);
                                }
                                worker_handle.finish_notifiers.push(finish_notifier);
                            } else {
                                debug!(
                                    "sink coordinator of {} is not running. Notify finish directly",
                                    sink_id.sink_id
                                );
                                send_with_err_check!(finish_notifier, ());
                            }
                        } else {
                            self.clean_up().await;
                            send_with_err_check!(finish_notifier, ());
                        }
                    }
                },
                ManagerEvent::CoordinatorWorkerFinished {
                    sink_id,
                    join_result,
                } => self.handle_coordinator_finished(sink_id, join_result),
            }
        }
        self.clean_up().await;
        info!("sink manager worker exited");
    }

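    /// Waits for the next manager event: a new request or a finished coordinator worker.
    /// Returns `None` when shutdown is signaled or the request channel is closed.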
    async fn next_event(&mut self) -> Option<ManagerEvent> {
        match select(
            select(
                pin!(self.request_rx.recv()),
                pin!(pending_on_none(
                    self.running_coordinator_worker_join_handles.next()
                )),
            ),
            &mut self.shutdown_rx,
        )
        .await
        {
            Either::Left((either, _)) => match either {
                Either::Left((Some(request), _)) => Some(ManagerEvent::NewRequest(request)),
                Either::Left((None, _)) => None,
                Either::Right(((sink_id, join_result), _)) => {
                    Some(ManagerEvent::CoordinatorWorkerFinished {
                        sink_id,
                        join_result,
                    })
                }
            },
            Either::Right(_) => None,
        }
    }

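    /// Signals every running coordinator worker to stop by dropping its request sender, then
    /// waits for all of them to finish.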
    async fn clean_up(&mut self) {
        info!("sink manager worker start cleaning up");
        for worker_handle in self.running_coordinator_worker.values_mut() {
            if let Some(sender) = worker_handle.request_sender.take() {
                // drop the sender to notify the coordinator worker to stop
                drop(sender);
            }
        }
        while let Some((sink_id, join_result)) =
            self.running_coordinator_worker_join_handles.next().await
        {
            self.handle_coordinator_finished(sink_id, join_result);
        }
        info!("sink manager worker finished cleaning up");
    }

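    /// Removes the bookkeeping of a finished coordinator worker, notifies any waiters, and
    /// logs the join result.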
    fn handle_coordinator_finished(&mut self, sink_id: SinkId, join_result: Result<(), JoinError>) {
        let worker_handle = self
            .running_coordinator_worker
            .remove(&sink_id)
            .expect("finished coordinator should have an associated worker handle");
        for finish_notifier in worker_handle.finish_notifiers {
            send_with_err_check!(finish_notifier, ());
        }
        match join_result {
            Ok(()) => {
                info!(
                    id = sink_id.sink_id,
                    "sink coordinator has gracefully finished",
                );
            }
            Err(err) => {
                error!(
                    id = sink_id.sink_id,
                    error = %err.as_report(),
                    "sink coordinator finished with error",
                );
            }
        }
    }

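    /// Routes a new writer handle to its sink's coordinator worker, spawning the worker when
    /// this is the first writer of the sink.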
    fn handle_new_sink_writer(
        &mut self,
        new_writer: SinkWriterCoordinationHandle,
        spawn_coordinator_worker: &mut impl SpawnCoordinatorFn,
    ) {
        let param = new_writer.param();
        let sink_id = param.sink_id;

        let handle = self
            .running_coordinator_worker
            .entry(param.sink_id)
            .or_insert_with(|| {
                // Launch the coordinator worker task on the first writer of this sink
                let (request_tx, request_rx) = unbounded_channel();
                let join_handle = spawn_coordinator_worker(param.clone(), request_rx);
                self.running_coordinator_worker_join_handles.push(
                    join_handle
                        .map(move |join_result| (sink_id, join_result))
                        .boxed(),
                );
                CoordinatorWorkerHandle {
                    request_sender: Some(request_tx),
                    finish_notifiers: Vec::new(),
                }
            });

        if let Some(sender) = handle.request_sender.as_mut() {
            send_with_err_check!(sender, new_writer);
        } else {
            warn!(
                "handle a new request while the sink coordinator is being stopped: {:?}",
                param
            );
            new_writer.abort(Status::internal("the sink is being stopped"));
        }
    }
}

#[cfg(test)]
mod tests {
    use std::future::{Future, poll_fn};
    use std::pin::pin;
    use std::sync::Arc;
    use std::task::Poll;

    use anyhow::anyhow;
    use async_trait::async_trait;
    use futures::future::{join, try_join};
    use futures::{FutureExt, StreamExt, TryFutureExt};
    use itertools::Itertools;
    use rand::seq::SliceRandom;
    use risingwave_common::bitmap::BitmapBuilder;
    use risingwave_common::hash::VirtualNode;
    use risingwave_connector::sink::catalog::{SinkId, SinkType};
    use risingwave_connector::sink::{SinkCommitCoordinator, SinkError, SinkParam};
    use risingwave_pb::connector_service::SinkMetadata;
    use risingwave_pb::connector_service::sink_metadata::{Metadata, SerializedMetadata};
    use risingwave_rpc_client::CoordinatorStreamHandle;
    use tokio::sync::mpsc::unbounded_channel;
    use tokio_stream::wrappers::ReceiverStream;

    use crate::manager::sink_coordination::SinkCoordinatorManager;
    use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
    use crate::manager::sink_coordination::manager::SinkCommittedEpochSubscriber;

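    // Test-only coordinator: `commit` delegates to the provided closure so each test can
    // assert on the epoch and metadata it receives.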
    struct MockCoordinator<C, F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError>> {
        context: C,
        f: F,
    }

    impl<C, F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError>> MockCoordinator<C, F> {
        fn new(context: C, f: F) -> Self {
            MockCoordinator { context, f }
        }
    }

    #[async_trait]
    impl<C: Send, F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError> + Send>
        SinkCommitCoordinator for MockCoordinator<C, F>
    {
        async fn init(
            &mut self,
            _subscriber: SinkCommittedEpochSubscriber,
        ) -> risingwave_connector::sink::Result<Option<u64>> {
            Ok(None)
        }

        async fn commit(
            &mut self,
            epoch: u64,
            metadata: Vec<SinkMetadata>,
        ) -> risingwave_connector::sink::Result<()> {
            (self.f)(epoch, metadata, &mut self.context)
        }
    }

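    // Two writers covering disjoint halves of the vnodes commit two epochs; the coordinator
    // only commits after both writers have reported their metadata for an epoch.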
    #[tokio::test]
    async fn test_basic() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch0 = 232;
        let epoch1 = 233;
        let epoch2 = 234;

        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);

        let metadata = [
            [vec![1u8, 2u8], vec![3u8, 4u8]],
            [vec![5u8, 6u8], vec![7u8, 8u8]],
        ];
        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });

        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                let metadata = metadata.clone();
                move |param, new_writer_rx| {
                    let metadata = metadata.clone();
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                param.clone(),
                                new_writer_rx,
                                MockCoordinator::new(
                                    0,
                                    |epoch, metadata_list, count: &mut usize| {
                                        *count += 1;
                                        let mut metadata_list =
                                            metadata_list
                                                .into_iter()
                                                .map(|metadata| match metadata {
                                                    SinkMetadata {
                                                        metadata:
                                                            Some(Metadata::Serialized(
                                                                SerializedMetadata { metadata },
                                                            )),
                                                    } => metadata,
                                                    _ => unreachable!(),
                                                })
                                                .collect_vec();
                                        metadata_list.sort();
                                        match *count {
                                            1 => {
                                                assert_eq!(epoch, epoch1);
                                                assert_eq!(2, metadata_list.len());
                                                assert_eq!(metadata[0][0], metadata_list[0]);
                                                assert_eq!(metadata[0][1], metadata_list[1]);
                                            }
                                            2 => {
                                                assert_eq!(epoch, epoch2);
                                                assert_eq!(2, metadata_list.len());
                                                assert_eq!(metadata[1][0], metadata_list[0]);
                                                assert_eq!(metadata[1][1], metadata_list[1]);
                                            }
                                            _ => unreachable!(),
                                        }
                                        Ok(())
                                    },
                                ),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
            .unwrap()
            .0
        };

        let (mut client1, mut client2) =
            join(build_client(vnode1), pin!(build_client(vnode2))).await;

        let (aligned_epoch1, aligned_epoch2) = try_join(
            client1.align_initial_epoch(epoch0),
            client2.align_initial_epoch(epoch1),
        )
        .await
        .unwrap();
        assert_eq!(aligned_epoch1, epoch1);
        assert_eq!(aligned_epoch2, epoch1);

        {
            // commit epoch1
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][1].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap())
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client1
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][0].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }

        // commit epoch2
        let mut commit_future = pin!(
            client1
                .commit(
                    epoch2,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata[1][0].clone(),
                        })),
                    },
                )
                .map(|result| result.unwrap())
        );
        assert!(
            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        join(
            commit_future,
            client2
                .commit(
                    epoch2,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata[1][1].clone(),
                        })),
                    },
                )
                .map(|result| result.unwrap()),
        )
        .await;
    }

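    // A single writer that covers all vnodes commits two epochs on its own.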
    #[tokio::test]
    async fn test_single_writer() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch1 = 233;
        let epoch2 = 234;

        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode = build_bitmap(&all_vnode);

        let metadata = [vec![1u8, 2u8], vec![3u8, 4u8]];
        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                let metadata = metadata.clone();
                move |param, new_writer_rx| {
                    let metadata = metadata.clone();
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                param.clone(),
                                new_writer_rx,
                                MockCoordinator::new(
                                    0,
                                    |epoch, metadata_list, count: &mut usize| {
                                        *count += 1;
                                        let mut metadata_list =
                                            metadata_list
                                                .into_iter()
                                                .map(|metadata| match metadata {
                                                    SinkMetadata {
                                                        metadata:
                                                            Some(Metadata::Serialized(
                                                                SerializedMetadata { metadata },
                                                            )),
                                                    } => metadata,
                                                    _ => unreachable!(),
                                                })
                                                .collect_vec();
                                        metadata_list.sort();
                                        match *count {
                                            1 => {
                                                assert_eq!(epoch, epoch1);
                                                assert_eq!(1, metadata_list.len());
                                                assert_eq!(metadata[0], metadata_list[0]);
                                            }
                                            2 => {
                                                assert_eq!(epoch, epoch2);
                                                assert_eq!(1, metadata_list.len());
                                                assert_eq!(metadata[1], metadata_list[0]);
                                            }
                                            _ => unreachable!(),
                                        }
                                        Ok(())
                                    },
                                ),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
            .unwrap()
            .0
        };

        let mut client = build_client(vnode).await;

        let aligned_epoch = client.align_initial_epoch(epoch1).await.unwrap();
        assert_eq!(aligned_epoch, epoch1);

        client
            .commit(
                epoch1,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata[0].clone(),
                    })),
                },
            )
            .await
            .unwrap();

        client
            .commit(
                epoch2,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata[1].clone(),
                    })),
                },
            )
            .await
            .unwrap();
    }

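    // If one writer drops before committing, the other writer's pending commit fails.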
    #[tokio::test]
    async fn test_partial_commit() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch = 233;

        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);

        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                move |param, new_writer_rx| {
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                param,
                                new_writer_rx,
                                MockCoordinator::new((), |_, _, _| unreachable!()),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
            .unwrap()
            .0
        };

        let (mut client1, client2) = join(build_client(vnode1), build_client(vnode2)).await;

        // commit epoch
        let mut commit_future = pin!(client1.commit(
            epoch,
            SinkMetadata {
                metadata: Some(Metadata::Serialized(SerializedMetadata {
                    metadata: vec![],
                })),
            },
        ));
        assert!(
            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        drop(client2);
        assert!(commit_future.await.is_err());
    }

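    // If the coordinator's commit fails, both writers' pending commits return an error.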
    #[tokio::test]
    async fn test_fail_commit() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch = 233;

        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);
        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                move |param, new_writer_rx| {
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        {
                            async move {
                                // validate the start request
                                assert_eq!(param, expected_param);
                                CoordinatorWorker::execute_coordinator(
                                    param,
                                    new_writer_rx,
                                    MockCoordinator::new((), |_, _, _| {
                                        Err(SinkError::Coordinator(anyhow!("failed to commit")))
                                    }),
                                    subscriber.clone(),
                                )
                                .await;
                            }
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
            .unwrap()
            .0
        };

        let (mut client1, mut client2) = join(build_client(vnode1), build_client(vnode2)).await;

        // commit epoch
        let mut commit_future = pin!(client1.commit(
            epoch,
            SinkMetadata {
                metadata: Some(Metadata::Serialized(SerializedMetadata {
                    metadata: vec![],
                })),
            },
        ));
        assert!(
            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        let (result1, result2) = join(
            commit_future,
            client2.commit(
                epoch,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: vec![],
                    })),
                },
            ),
        )
        .await;
        assert!(result1.is_err());
        assert!(result2.is_err());
    }

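    // Writers update their vnode bitmaps across epochs (scale out to three writers, then back
    // to two); commits keep working across the reconfigurations.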
    #[tokio::test]
    async fn test_update_vnode_bitmap() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch1 = 233;
        let epoch2 = 234;
        let epoch3 = 235;
        let epoch4 = 236;

        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);

        let metadata = [
            [vec![1u8, 2u8], vec![3u8, 4u8]],
            [vec![5u8, 6u8], vec![7u8, 8u8]],
        ];

        let metadata_scale_out = [vec![9u8, 10u8], vec![11u8, 12u8], vec![13u8, 14u8]];
        let metadata_scale_in = [vec![13u8, 14u8], vec![15u8, 16u8]];
        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                let metadata = metadata.clone();
                let metadata_scale_out = metadata_scale_out.clone();
                let metadata_scale_in = metadata_scale_in.clone();
                move |param, new_writer_rx| {
                    let metadata = metadata.clone();
                    let metadata_scale_out = metadata_scale_out.clone();
                    let metadata_scale_in = metadata_scale_in.clone();
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                param.clone(),
                                new_writer_rx,
                                MockCoordinator::new(
                                    0,
                                    |epoch, metadata_list, count: &mut usize| {
                                        *count += 1;
                                        let mut metadata_list =
                                            metadata_list
                                                .into_iter()
                                                .map(|metadata| match metadata {
                                                    SinkMetadata {
                                                        metadata:
                                                            Some(Metadata::Serialized(
                                                                SerializedMetadata { metadata },
                                                            )),
                                                    } => metadata,
                                                    _ => unreachable!(),
                                                })
                                                .collect_vec();
                                        metadata_list.sort();
                                        let (expected_epoch, expected_metadata_list) = match *count
                                        {
                                            1 => (epoch1, metadata[0].as_slice()),
                                            2 => (epoch2, metadata[1].as_slice()),
                                            3 => (epoch3, metadata_scale_out.as_slice()),
                                            4 => (epoch4, metadata_scale_in.as_slice()),
                                            _ => unreachable!(),
                                        };
                                        assert_eq!(expected_epoch, epoch);
                                        assert_eq!(expected_metadata_list, &metadata_list);
                                        Ok(())
                                    },
                                ),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
        };

        let ((mut client1, _), (mut client2, _)) =
            try_join(build_client(vnode1), pin!(build_client(vnode2)))
                .await
                .unwrap();

        let (aligned_epoch1, aligned_epoch2) = try_join(
            client1.align_initial_epoch(epoch1),
            client2.align_initial_epoch(epoch1),
        )
        .await
        .unwrap();
        assert_eq!(aligned_epoch1, epoch1);
        assert_eq!(aligned_epoch2, epoch1);

        {
            // commit epoch1
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][1].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap())
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client1
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][0].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }

        let (vnode1, vnode2, vnode3) = {
            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            let (second, third) = second.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            (
                build_bitmap(first),
                build_bitmap(second),
                build_bitmap(third),
            )
        };

        let mut build_client3_future = pin!(build_client(vnode3));
        assert!(
            poll_fn(|cx| Poll::Ready(build_client3_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        let mut client3;
        {
            {
                // commit epoch2
                let mut commit_future = pin!(
                    client1
                        .commit(
                            epoch2,
                            SinkMetadata {
                                metadata: Some(Metadata::Serialized(SerializedMetadata {
                                    metadata: metadata[1][0].clone(),
                                })),
                            },
                        )
                        .map_err(Into::into)
                );
                assert!(
                    poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                        .await
                        .is_pending()
                );
                try_join(
                    commit_future,
                    client2.commit(
                        epoch2,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[1][1].clone(),
                            })),
                        },
                    ),
                )
                .await
                .unwrap();
            }

            client3 = {
                let (
                    (client3, init_epoch),
                    (update_vnode_bitmap_epoch1, update_vnode_bitmap_epoch2),
                ) = try_join(
                    build_client3_future,
                    try_join(
                        client1.update_vnode_bitmap(&vnode1),
                        client2.update_vnode_bitmap(&vnode2),
                    )
                    .map_err(Into::into),
                )
                .await
                .unwrap();
                assert_eq!(init_epoch, Some(epoch2));
                assert_eq!(update_vnode_bitmap_epoch1, epoch2);
                assert_eq!(update_vnode_bitmap_epoch2, epoch2);
                client3
            };
            let mut commit_future3 = pin!(client3.commit(
                epoch3,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata_scale_out[2].clone(),
                    })),
                },
            ));
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            let mut commit_future1 = pin!(client1.commit(
                epoch3,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata_scale_out[0].clone(),
                    })),
                },
            ));
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future1.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            try_join(
                client2.commit(
                    epoch3,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata_scale_out[1].clone(),
                        })),
                    },
                ),
                try_join(commit_future1, commit_future3),
            )
            .await
            .unwrap();
        }

        let (vnode2, vnode3) = {
            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            (build_bitmap(first), build_bitmap(second))
        };

        {
            let (_, (update_vnode_bitmap_epoch2, update_vnode_bitmap_epoch3)) = try_join(
                client1.stop(),
                try_join(
                    client2.update_vnode_bitmap(&vnode2),
                    client3.update_vnode_bitmap(&vnode3),
                ),
            )
            .await
            .unwrap();
            assert_eq!(update_vnode_bitmap_epoch2, epoch3);
            assert_eq!(update_vnode_bitmap_epoch3, epoch3);
        }

        {
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch4,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata_scale_in[0].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap())
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client3
                    .commit(
                        epoch4,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata_scale_in[1].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }
    }
}