risingwave_meta/manager/sink_coordination/manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::pin::pin;
use std::sync::Arc;

use anyhow::anyhow;
use futures::future::{BoxFuture, Either, select};
use futures::stream::FuturesUnordered;
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use risingwave_common::bitmap::Bitmap;
use risingwave_connector::connector_common::IcebergSinkCompactionUpdate;
use risingwave_connector::sink::catalog::SinkId;
use risingwave_connector::sink::{SinkCommittedEpochSubscriber, SinkError, SinkParam};
use risingwave_pb::connector_service::coordinate_request::Msg;
use risingwave_pb::connector_service::{CoordinateRequest, CoordinateResponse, coordinate_request};
use rw_futures_util::pending_on_none;
use sea_orm::DatabaseConnection;
use thiserror_ext::AsReport;
use tokio::sync::mpsc;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tokio::sync::oneshot::{Receiver, Sender, channel};
use tokio::task::{JoinError, JoinHandle};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Status;
use tracing::{debug, error, info, warn};

use crate::hummock::HummockManagerRef;
use crate::manager::MetadataManager;
use crate::manager::sink_coordination::SinkWriterRequestStream;
use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
use crate::manager::sink_coordination::handle::SinkWriterCoordinationHandle;
macro_rules! send_with_err_check {
    ($tx:expr, $msg:expr) => {
        if $tx.send($msg).is_err() {
            error!("unable to send msg");
        }
    };
}

macro_rules! send_await_with_err_check {
    ($tx:expr, $msg:expr) => {
        if $tx.send($msg).await.is_err() {
            error!("unable to send msg");
        }
    };
}

const BOUNDED_CHANNEL_SIZE: usize = 16;

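/// Requests sent from [`SinkCoordinatorManager`] to the background [`ManagerWorker`].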
enum ManagerRequest {
    NewSinkWriter(SinkWriterCoordinationHandle),
    StopCoordinator {
        finish_notifier: Sender<()>,
        /// Sink id to stop. When `None`, stop all sink coordinators.
        sink_id: Option<SinkId>,
    },
}

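/// Cloneable front-end of the sink coordination service. All requests are forwarded to a
/// dedicated [`ManagerWorker`] task through a bounded channel.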
#[derive(Clone)]
pub struct SinkCoordinatorManager {
    request_tx: mpsc::Sender<ManagerRequest>,
}

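/// Builds a subscriber that resolves a sink to its first state table and subscribes to the
/// committed epochs of that table via the hummock manager.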
fn new_committed_epoch_subscriber(
    hummock_manager: HummockManagerRef,
    metadata_manager: MetadataManager,
) -> SinkCommittedEpochSubscriber {
    Arc::new(move |sink_id| {
        let hummock_manager = hummock_manager.clone();
        let metadata_manager = metadata_manager.clone();
        async move {
            let state_table_ids = metadata_manager
                .get_sink_state_table_ids(sink_id.sink_id as _)
                .await
                .map_err(SinkError::from)?;
            let Some(table_id) = state_table_ids.first() else {
                return Err(anyhow!("no state table id in sink: {}", sink_id).into());
            };
            hummock_manager
                .subscribe_table_committed_epoch(*table_id)
                .await
                .map_err(SinkError::from)
        }
        .boxed()
    })
}

impl SinkCoordinatorManager {
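    /// Spawns the manager worker with a coordinator-worker factory backed by
    /// [`CoordinatorWorker::run`]. Returns the manager handle together with the worker's join
    /// handle and its shutdown sender.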
    pub fn start_worker(
        db: DatabaseConnection,
        hummock_manager: HummockManagerRef,
        metadata_manager: MetadataManager,
        iceberg_compact_stat_sender: UnboundedSender<IcebergSinkCompactionUpdate>,
    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
        let subscriber =
            new_committed_epoch_subscriber(hummock_manager.clone(), metadata_manager.clone());
        Self::start_worker_with_spawn_worker(move |param, manager_request_stream| {
            tokio::spawn(CoordinatorWorker::run(
                param,
                manager_request_stream,
                db.clone(),
                subscriber.clone(),
                iceberg_compact_stat_sender.clone(),
            ))
        })
    }

    fn start_worker_with_spawn_worker(
        spawn_coordinator_worker: impl SpawnCoordinatorFn,
    ) -> (Self, (JoinHandle<()>, Sender<()>)) {
        let (request_tx, request_rx) = mpsc::channel(BOUNDED_CHANNEL_SIZE);
        let (shutdown_tx, shutdown_rx) = channel();
        let worker = ManagerWorker::new(request_rx, shutdown_rx);
        let join_handle = tokio::spawn(worker.execute(spawn_coordinator_worker));
        (
            SinkCoordinatorManager { request_tx },
            (join_handle, shutdown_tx),
        )
    }

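    /// Handles a new writer stream. The first message must be a `StartRequest` carrying the
    /// sink param and the writer's vnode bitmap; the returned stream yields the coordinator's
    /// responses to this writer.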
    pub async fn handle_new_request(
        &self,
        mut request_stream: SinkWriterRequestStream,
    ) -> Result<impl Stream<Item = Result<CoordinateResponse, Status>> + use<>, Status> {
        let (param, vnode_bitmap) = match request_stream.try_next().await? {
            Some(CoordinateRequest {
                msg:
                    Some(Msg::StartRequest(coordinate_request::StartCoordinationRequest {
                        param: Some(param),
                        vnode_bitmap: Some(vnode_bitmap),
                    })),
            }) => (SinkParam::from_proto(param), Bitmap::from(&vnode_bitmap)),
            msg => {
                return Err(Status::invalid_argument(format!(
                    "expected CoordinateRequest::StartRequest as the first request, got {:?}",
                    msg
                )));
            }
        };
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        self.request_tx
            .send(ManagerRequest::NewSinkWriter(
                SinkWriterCoordinationHandle::new(request_stream, response_tx, param, vnode_bitmap),
            ))
            .await
            .map_err(|_| {
                Status::unavailable(
                    "unable to send to sink manager worker. The worker may have stopped",
                )
            })?;

        Ok(UnboundedReceiverStream::new(response_rx))
    }

    async fn stop_coordinator(&self, sink_id: Option<SinkId>) {
        let (tx, rx) = channel();
        send_await_with_err_check!(
            self.request_tx,
            ManagerRequest::StopCoordinator {
                finish_notifier: tx,
                sink_id,
            }
        );
        if rx.await.is_err() {
            error!("failed to wait for the sink manager worker to reset");
        }
        info!("successfully stopped coordinator: {:?}", sink_id);
    }

    pub async fn reset(&self) {
        self.stop_coordinator(None).await;
    }

    pub async fn stop_sink_coordinator(&self, sink_id: SinkId) {
        self.stop_coordinator(Some(sink_id)).await;
    }
}

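/// Per-sink bookkeeping kept by [`ManagerWorker`] for a running coordinator worker.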
struct CoordinatorWorkerHandle {
    /// Sender to the coordinator worker. Dropping the sender acts as a stop signal.
    request_sender: Option<UnboundedSender<SinkWriterCoordinationHandle>>,
    /// Notifiers to signal when the coordinator worker stops.
    finish_notifiers: Vec<Sender<()>>,
}

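/// Background task owning all per-sink coordinator workers: it spawns them on demand,
/// forwards new writers to them, and reaps them when they finish.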
struct ManagerWorker {
    request_rx: mpsc::Receiver<ManagerRequest>,
    // Polled by `&mut` reference in `next_event` so that the receiver is not consumed.
    shutdown_rx: Receiver<()>,

    running_coordinator_worker_join_handles:
        FuturesUnordered<BoxFuture<'static, (SinkId, Result<(), JoinError>)>>,
    running_coordinator_worker: HashMap<SinkId, CoordinatorWorkerHandle>,
}

enum ManagerEvent {
    NewRequest(ManagerRequest),
    CoordinatorWorkerFinished {
        sink_id: SinkId,
        join_result: Result<(), JoinError>,
    },
}

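/// Factory that spawns a coordinator worker for one sink, given the sink param and the
/// channel on which new writer handles will arrive.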
trait SpawnCoordinatorFn = FnMut(SinkParam, UnboundedReceiver<SinkWriterCoordinationHandle>) -> JoinHandle<()>
    + Send
    + 'static;

impl ManagerWorker {
    fn new(request_rx: mpsc::Receiver<ManagerRequest>, shutdown_rx: Receiver<()>) -> Self {
        ManagerWorker {
            request_rx,
            shutdown_rx,
            running_coordinator_worker_join_handles: Default::default(),
            running_coordinator_worker: Default::default(),
        }
    }

    async fn execute(mut self, mut spawn_coordinator_worker: impl SpawnCoordinatorFn) {
        while let Some(event) = self.next_event().await {
            match event {
                ManagerEvent::NewRequest(request) => match request {
                    ManagerRequest::NewSinkWriter(request) => {
                        self.handle_new_sink_writer(request, &mut spawn_coordinator_worker)
                    }
                    ManagerRequest::StopCoordinator {
                        finish_notifier,
                        sink_id,
                    } => {
                        if let Some(sink_id) = sink_id {
                            if let Some(worker_handle) =
                                self.running_coordinator_worker.get_mut(&sink_id)
                            {
                                if let Some(sender) = worker_handle.request_sender.take() {
                                    // drop the sender as a signal to notify the coordinator worker
                                    // to stop
                                    drop(sender);
                                }
                                worker_handle.finish_notifiers.push(finish_notifier);
                            } else {
                                debug!(
                                    "sink coordinator of {} is not running. Notify finish directly",
                                    sink_id.sink_id
                                );
                                send_with_err_check!(finish_notifier, ());
                            }
                        } else {
                            self.clean_up().await;
                            send_with_err_check!(finish_notifier, ());
                        }
                    }
                },
                ManagerEvent::CoordinatorWorkerFinished {
                    sink_id,
                    join_result,
                } => self.handle_coordinator_finished(sink_id, join_result),
            }
        }
        self.clean_up().await;
        info!("sink manager worker exited");
    }

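    /// Waits for the next event: a new manager request, the completion of a coordinator
    /// worker, or `None` once the request channel is closed or a shutdown signal arrives.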
    async fn next_event(&mut self) -> Option<ManagerEvent> {
        match select(
            select(
                pin!(self.request_rx.recv()),
                pin!(pending_on_none(
                    self.running_coordinator_worker_join_handles.next()
                )),
            ),
            &mut self.shutdown_rx,
        )
        .await
        {
            Either::Left((either, _)) => match either {
                Either::Left((Some(request), _)) => Some(ManagerEvent::NewRequest(request)),
                Either::Left((None, _)) => None,
                Either::Right(((sink_id, join_result), _)) => {
                    Some(ManagerEvent::CoordinatorWorkerFinished {
                        sink_id,
                        join_result,
                    })
                }
            },
            Either::Right(_) => None,
        }
    }

    async fn clean_up(&mut self) {
        info!("sink manager worker starts cleaning up");
        for worker_handle in self.running_coordinator_worker.values_mut() {
            if let Some(sender) = worker_handle.request_sender.take() {
                // drop the sender to notify the coordinator worker to stop
                drop(sender);
            }
        }
        while let Some((sink_id, join_result)) =
            self.running_coordinator_worker_join_handles.next().await
        {
            self.handle_coordinator_finished(sink_id, join_result);
        }
        info!("sink manager worker finished cleaning up");
    }

    fn handle_coordinator_finished(&mut self, sink_id: SinkId, join_result: Result<(), JoinError>) {
        let worker_handle = self
            .running_coordinator_worker
            .remove(&sink_id)
            .expect("finished coordinator should have an associated worker handle");
        for finish_notifier in worker_handle.finish_notifiers {
            send_with_err_check!(finish_notifier, ());
        }
        match join_result {
            Ok(()) => {
                info!(
                    id = sink_id.sink_id,
                    "sink coordinator has gracefully finished",
                );
            }
            Err(err) => {
                error!(
                    id = sink_id.sink_id,
                    error = %err.as_report(),
                    "sink coordinator finished with error",
                );
            }
        }
    }

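    /// Routes a new writer to the coordinator worker of its sink, spawning that worker first
    /// if this is the sink's first writer.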
    fn handle_new_sink_writer(
        &mut self,
        new_writer: SinkWriterCoordinationHandle,
        spawn_coordinator_worker: &mut impl SpawnCoordinatorFn,
    ) {
        let param = new_writer.param();
        let sink_id = param.sink_id;

        let handle = self
            .running_coordinator_worker
            .entry(param.sink_id)
            .or_insert_with(|| {
                // Launch the coordinator worker task when this is the first writer of the sink
                let (request_tx, request_rx) = unbounded_channel();
                let join_handle = spawn_coordinator_worker(param.clone(), request_rx);
                self.running_coordinator_worker_join_handles.push(
                    join_handle
                        .map(move |join_result| (sink_id, join_result))
                        .boxed(),
                );
                CoordinatorWorkerHandle {
                    request_sender: Some(request_tx),
                    finish_notifiers: Vec::new(),
                }
            });

        if let Some(sender) = handle.request_sender.as_mut() {
            send_with_err_check!(sender, new_writer);
        } else {
            warn!(
                "received a new writer request while the sink coordinator is being stopped: {:?}",
                param
            );
            new_writer.abort(Status::internal("the sink is being stopped"));
        }
    }
}

#[cfg(test)]
mod tests {
    use std::future::{Future, poll_fn};
    use std::pin::pin;
    use std::sync::Arc;
    use std::task::Poll;

    use anyhow::anyhow;
    use async_trait::async_trait;
    use futures::future::{join, try_join};
    use futures::{FutureExt, StreamExt, TryFutureExt};
    use itertools::Itertools;
    use rand::seq::SliceRandom;
    use risingwave_common::bitmap::BitmapBuilder;
    use risingwave_common::hash::VirtualNode;
    use risingwave_connector::sink::catalog::{SinkId, SinkType};
    use risingwave_connector::sink::{SinkCommitCoordinator, SinkError, SinkParam};
    use risingwave_pb::connector_service::SinkMetadata;
    use risingwave_pb::connector_service::sink_metadata::{Metadata, SerializedMetadata};
    use risingwave_rpc_client::CoordinatorStreamHandle;
    use tokio::sync::mpsc::unbounded_channel;
    use tokio_stream::wrappers::ReceiverStream;

    use crate::manager::sink_coordination::SinkCoordinatorManager;
    use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker;
    use crate::manager::sink_coordination::manager::SinkCommittedEpochSubscriber;

    struct MockCoordinator<C, F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError>> {
        context: C,
        f: F,
    }

    impl<C, F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError>> MockCoordinator<C, F> {
        fn new(context: C, f: F) -> Self {
            MockCoordinator { context, f }
        }
    }

    #[async_trait]
    impl<C: Send, F: FnMut(u64, Vec<SinkMetadata>, &mut C) -> Result<(), SinkError> + Send>
        SinkCommitCoordinator for MockCoordinator<C, F>
    {
        async fn init(
            &mut self,
            _subscriber: SinkCommittedEpochSubscriber,
        ) -> risingwave_connector::sink::Result<Option<u64>> {
            Ok(None)
        }

        async fn commit(
            &mut self,
            epoch: u64,
            metadata: Vec<SinkMetadata>,
        ) -> risingwave_connector::sink::Result<()> {
            (self.f)(epoch, metadata, &mut self.context)
        }
    }

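    // Two writers holding disjoint halves of the vnodes commit two epochs; the mock
    // coordinator checks that both pieces of metadata arrive for each epoch.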
    #[tokio::test]
    async fn test_basic() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch0 = 232;
        let epoch1 = 233;
        let epoch2 = 234;

        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);

        let metadata = [
            [vec![1u8, 2u8], vec![3u8, 4u8]],
            [vec![5u8, 6u8], vec![7u8, 8u8]],
        ];
        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });

        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                let metadata = metadata.clone();
                move |param, new_writer_rx| {
                    let metadata = metadata.clone();
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                param.clone(),
                                new_writer_rx,
                                MockCoordinator::new(
                                    0,
                                    |epoch, metadata_list, count: &mut usize| {
                                        *count += 1;
                                        let mut metadata_list =
                                            metadata_list
                                                .into_iter()
                                                .map(|metadata| match metadata {
                                                    SinkMetadata {
                                                        metadata:
                                                            Some(Metadata::Serialized(
                                                                SerializedMetadata { metadata },
                                                            )),
                                                    } => metadata,
                                                    _ => unreachable!(),
                                                })
                                                .collect_vec();
                                        metadata_list.sort();
                                        match *count {
                                            1 => {
                                                assert_eq!(epoch, epoch1);
                                                assert_eq!(2, metadata_list.len());
                                                assert_eq!(metadata[0][0], metadata_list[0]);
                                                assert_eq!(metadata[0][1], metadata_list[1]);
                                            }
                                            2 => {
                                                assert_eq!(epoch, epoch2);
                                                assert_eq!(2, metadata_list.len());
                                                assert_eq!(metadata[1][0], metadata_list[0]);
                                                assert_eq!(metadata[1][1], metadata_list[1]);
                                            }
                                            _ => unreachable!(),
                                        }
                                        Ok(())
                                    },
                                ),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
            .unwrap()
            .0
        };

        let (mut client1, mut client2) =
            join(build_client(vnode1), pin!(build_client(vnode2))).await;

        let (aligned_epoch1, aligned_epoch2) = try_join(
            client1.align_initial_epoch(epoch0),
            client2.align_initial_epoch(epoch1),
        )
        .await
        .unwrap();
        assert_eq!(aligned_epoch1, epoch1);
        assert_eq!(aligned_epoch2, epoch1);

        {
            // commit epoch1
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][1].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap())
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client1
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][0].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }

        // commit epoch2
        let mut commit_future = pin!(
            client1
                .commit(
                    epoch2,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata[1][0].clone(),
                        })),
                    },
                )
                .map(|result| result.unwrap())
        );
        assert!(
            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        join(
            commit_future,
            client2
                .commit(
                    epoch2,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata[1][1].clone(),
                        })),
                    },
                )
                .map(|result| result.unwrap()),
        )
        .await;
    }

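    // A single writer covering all vnodes can commit each epoch without waiting for peers.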
    #[tokio::test]
    async fn test_single_writer() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch1 = 233;
        let epoch2 = 234;

        let all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode = build_bitmap(&all_vnode);

        let metadata = [vec![1u8, 2u8], vec![3u8, 4u8]];
        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                let metadata = metadata.clone();
                move |param, new_writer_rx| {
                    let metadata = metadata.clone();
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                param.clone(),
                                new_writer_rx,
                                MockCoordinator::new(
                                    0,
                                    |epoch, metadata_list, count: &mut usize| {
                                        *count += 1;
                                        let mut metadata_list =
                                            metadata_list
                                                .into_iter()
                                                .map(|metadata| match metadata {
                                                    SinkMetadata {
                                                        metadata:
                                                            Some(Metadata::Serialized(
                                                                SerializedMetadata { metadata },
                                                            )),
                                                    } => metadata,
                                                    _ => unreachable!(),
                                                })
                                                .collect_vec();
                                        metadata_list.sort();
                                        match *count {
                                            1 => {
                                                assert_eq!(epoch, epoch1);
                                                assert_eq!(1, metadata_list.len());
                                                assert_eq!(metadata[0], metadata_list[0]);
                                            }
                                            2 => {
                                                assert_eq!(epoch, epoch2);
                                                assert_eq!(1, metadata_list.len());
                                                assert_eq!(metadata[1], metadata_list[0]);
                                            }
                                            _ => unreachable!(),
                                        }
                                        Ok(())
                                    },
                                ),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
            .unwrap()
            .0
        };

        let mut client = build_client(vnode).await;

        let aligned_epoch = client.align_initial_epoch(epoch1).await.unwrap();
        assert_eq!(aligned_epoch, epoch1);

        client
            .commit(
                epoch1,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata[0].clone(),
                    })),
                },
            )
            .await
            .unwrap();

        client
            .commit(
                epoch2,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata[1].clone(),
                    })),
                },
            )
            .await
            .unwrap();
    }

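    // If one writer goes away before committing, the pending commit of the remaining writer
    // fails instead of hanging forever.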
    #[tokio::test]
    async fn test_partial_commit() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch = 233;

        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);

        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                move |param, new_writer_rx| {
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                param,
                                new_writer_rx,
                                MockCoordinator::new((), |_, _, _| unreachable!()),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
            .unwrap()
            .0
        };

        let (mut client1, client2) = join(build_client(vnode1), build_client(vnode2)).await;

        // commit epoch
        let mut commit_future = pin!(client1.commit(
            epoch,
            SinkMetadata {
                metadata: Some(Metadata::Serialized(SerializedMetadata {
                    metadata: vec![],
                })),
            },
        ));
        assert!(
            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        drop(client2);
        assert!(commit_future.await.is_err());
    }

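    // When the coordinator itself fails the commit, both writers observe the error.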
    #[tokio::test]
    async fn test_fail_commit() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch = 233;

        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);
        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                move |param, new_writer_rx| {
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        {
                            async move {
                                // validate the start request
                                assert_eq!(param, expected_param);
                                CoordinatorWorker::execute_coordinator(
                                    param,
                                    new_writer_rx,
                                    MockCoordinator::new((), |_, _, _| {
                                        Err(SinkError::Coordinator(anyhow!("failed to commit")))
                                    }),
                                    subscriber.clone(),
                                )
                                .await;
                            }
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
            .unwrap()
            .0
        };

        let (mut client1, mut client2) = join(build_client(vnode1), build_client(vnode2)).await;

        // commit epoch
        let mut commit_future = pin!(client1.commit(
            epoch,
            SinkMetadata {
                metadata: Some(Metadata::Serialized(SerializedMetadata {
                    metadata: vec![],
                })),
            },
        ));
        assert!(
            poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        let (result1, result2) = join(
            commit_future,
            client2.commit(
                epoch,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: vec![],
                    })),
                },
            ),
        )
        .await;
        assert!(result1.is_err());
        assert!(result2.is_err());
    }

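    // Writers rescale between epochs: scale out by adding a third writer, then scale in by
    // stopping one, with `update_vnode_bitmap` keeping the commit protocol aligned.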
    #[tokio::test]
    async fn test_update_vnode_bitmap() {
        let param = SinkParam {
            sink_id: SinkId::from(1),
            sink_name: "test".into(),
            properties: Default::default(),
            columns: vec![],
            downstream_pk: vec![],
            sink_type: SinkType::AppendOnly,
            format_desc: None,
            db_name: "test".into(),
            sink_from_name: "test".into(),
        };

        let epoch1 = 233;
        let epoch2 = 234;
        let epoch3 = 235;
        let epoch4 = 236;

        let mut all_vnode = (0..VirtualNode::COUNT_FOR_TEST).collect_vec();
        all_vnode.shuffle(&mut rand::rng());
        let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 2);
        let build_bitmap = |indexes: &[usize]| {
            let mut builder = BitmapBuilder::zeroed(VirtualNode::COUNT_FOR_TEST);
            for i in indexes {
                builder.set(*i, true);
            }
            builder.finish()
        };
        let vnode1 = build_bitmap(first);
        let vnode2 = build_bitmap(second);

        let metadata = [
            [vec![1u8, 2u8], vec![3u8, 4u8]],
            [vec![5u8, 6u8], vec![7u8, 8u8]],
        ];

        let metadata_scale_out = [vec![9u8, 10u8], vec![11u8, 12u8], vec![13u8, 14u8]];
        let metadata_scale_in = [vec![13u8, 14u8], vec![15u8, 16u8]];
        let mock_subscriber: SinkCommittedEpochSubscriber = Arc::new(move |_sink_id: SinkId| {
            let (_sender, receiver) = unbounded_channel();

            async move { Ok((1, receiver)) }.boxed()
        });
        let (manager, (_join_handle, _stop_tx)) =
            SinkCoordinatorManager::start_worker_with_spawn_worker({
                let expected_param = param.clone();
                let metadata = metadata.clone();
                let metadata_scale_out = metadata_scale_out.clone();
                let metadata_scale_in = metadata_scale_in.clone();
                move |param, new_writer_rx| {
                    let metadata = metadata.clone();
                    let metadata_scale_out = metadata_scale_out.clone();
                    let metadata_scale_in = metadata_scale_in.clone();
                    let expected_param = expected_param.clone();
                    tokio::spawn({
                        let subscriber = mock_subscriber.clone();
                        async move {
                            // validate the start request
                            assert_eq!(param, expected_param);
                            CoordinatorWorker::execute_coordinator(
                                param.clone(),
                                new_writer_rx,
                                MockCoordinator::new(
                                    0,
                                    |epoch, metadata_list, count: &mut usize| {
                                        *count += 1;
                                        let mut metadata_list =
                                            metadata_list
                                                .into_iter()
                                                .map(|metadata| match metadata {
                                                    SinkMetadata {
                                                        metadata:
                                                            Some(Metadata::Serialized(
                                                                SerializedMetadata { metadata },
                                                            )),
                                                    } => metadata,
                                                    _ => unreachable!(),
                                                })
                                                .collect_vec();
                                        metadata_list.sort();
                                        let (expected_epoch, expected_metadata_list) = match *count
                                        {
                                            1 => (epoch1, metadata[0].as_slice()),
                                            2 => (epoch2, metadata[1].as_slice()),
                                            3 => (epoch3, metadata_scale_out.as_slice()),
                                            4 => (epoch4, metadata_scale_in.as_slice()),
                                            _ => unreachable!(),
                                        };
                                        assert_eq!(expected_epoch, epoch);
                                        assert_eq!(expected_metadata_list, &metadata_list);
                                        Ok(())
                                    },
                                ),
                                subscriber.clone(),
                            )
                            .await;
                        }
                    })
                }
            });

        let build_client = |vnode| async {
            CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async {
                Ok(tonic::Response::new(
                    manager
                        .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed())
                        .await
                        .unwrap()
                        .boxed(),
                ))
            })
            .await
        };

        let ((mut client1, _), (mut client2, _)) =
            try_join(build_client(vnode1), pin!(build_client(vnode2)))
                .await
                .unwrap();

        let (aligned_epoch1, aligned_epoch2) = try_join(
            client1.align_initial_epoch(epoch1),
            client2.align_initial_epoch(epoch1),
        )
        .await
        .unwrap();
        assert_eq!(aligned_epoch1, epoch1);
        assert_eq!(aligned_epoch2, epoch1);

        {
            // commit epoch1
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][1].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap())
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client1
                    .commit(
                        epoch1,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[0][0].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }

        let (vnode1, vnode2, vnode3) = {
            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            let (second, third) = second.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            (
                build_bitmap(first),
                build_bitmap(second),
                build_bitmap(third),
            )
        };

        let mut build_client3_future = pin!(build_client(vnode3));
        assert!(
            poll_fn(|cx| Poll::Ready(build_client3_future.as_mut().poll(cx)))
                .await
                .is_pending()
        );
        let mut client3;
        {
            {
                // commit epoch2
                let mut commit_future = pin!(
                    client1
                        .commit(
                            epoch2,
                            SinkMetadata {
                                metadata: Some(Metadata::Serialized(SerializedMetadata {
                                    metadata: metadata[1][0].clone(),
                                })),
                            },
                        )
                        .map_err(Into::into)
                );
                assert!(
                    poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                        .await
                        .is_pending()
                );
                try_join(
                    commit_future,
                    client2.commit(
                        epoch2,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata[1][1].clone(),
                            })),
                        },
                    ),
                )
                .await
                .unwrap();
            }

            client3 = {
                let (
                    (client3, init_epoch),
                    (update_vnode_bitmap_epoch1, update_vnode_bitmap_epoch2),
                ) = try_join(
                    build_client3_future,
                    try_join(
                        client1.update_vnode_bitmap(&vnode1),
                        client2.update_vnode_bitmap(&vnode2),
                    )
                    .map_err(Into::into),
                )
                .await
                .unwrap();
                assert_eq!(init_epoch, Some(epoch2));
                assert_eq!(update_vnode_bitmap_epoch1, epoch2);
                assert_eq!(update_vnode_bitmap_epoch2, epoch2);
                client3
            };
            let mut commit_future3 = pin!(client3.commit(
                epoch3,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata_scale_out[2].clone(),
                    })),
                },
            ));
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            let mut commit_future1 = pin!(client1.commit(
                epoch3,
                SinkMetadata {
                    metadata: Some(Metadata::Serialized(SerializedMetadata {
                        metadata: metadata_scale_out[0].clone(),
                    })),
                },
            ));
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future1.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future3.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            try_join(
                client2.commit(
                    epoch3,
                    SinkMetadata {
                        metadata: Some(Metadata::Serialized(SerializedMetadata {
                            metadata: metadata_scale_out[1].clone(),
                        })),
                    },
                ),
                try_join(commit_future1, commit_future3),
            )
            .await
            .unwrap();
        }

        let (vnode2, vnode3) = {
            let (first, second) = all_vnode.split_at(VirtualNode::COUNT_FOR_TEST / 3);
            (build_bitmap(first), build_bitmap(second))
        };

        {
            let (_, (update_vnode_bitmap_epoch2, update_vnode_bitmap_epoch3)) = try_join(
                client1.stop(),
                try_join(
                    client2.update_vnode_bitmap(&vnode2),
                    client3.update_vnode_bitmap(&vnode3),
                ),
            )
            .await
            .unwrap();
            assert_eq!(update_vnode_bitmap_epoch2, epoch3);
            assert_eq!(update_vnode_bitmap_epoch3, epoch3);
        }

        {
            let mut commit_future = pin!(
                client2
                    .commit(
                        epoch4,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata_scale_in[0].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap())
            );
            assert!(
                poll_fn(|cx| Poll::Ready(commit_future.as_mut().poll(cx)))
                    .await
                    .is_pending()
            );
            join(
                commit_future,
                client3
                    .commit(
                        epoch4,
                        SinkMetadata {
                            metadata: Some(Metadata::Serialized(SerializedMetadata {
                                metadata: metadata_scale_in[1].clone(),
                            })),
                        },
                    )
                    .map(|result| result.unwrap()),
            )
            .await;
        }
    }
}