sink_bench/main.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![feature(coroutines)]
#![feature(proc_macro_hygiene)]
#![feature(stmt_expr_attributes)]
#![recursion_limit = "256"]

use core::str::FromStr;
use core::sync::atomic::Ordering;
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;

use anyhow::anyhow;
use clap::Parser;
use futures::channel::oneshot;
use futures::prelude::future::Either;
use futures::prelude::stream::{BoxStream, PollNext};
use futures::stream::select_with_strategy;
use futures::{FutureExt, StreamExt, TryStreamExt};
use futures_async_stream::try_stream;
use itertools::Itertools;
use plotters::backend::SVGBackend;
use plotters::chart::ChartBuilder;
use plotters::drawing::IntoDrawingArea;
use plotters::element::{Circle, EmptyElement};
use plotters::series::{LineSeries, PointSeries};
use plotters::style::{IntoFont, RED, WHITE};
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog::ColumnId;
use risingwave_common::types::DataType;
use risingwave_connector::dispatch_sink;
use risingwave_connector::parser::{
    EncodingProperties, ParserConfig, ProtocolProperties, SpecificParserConfig,
};
use risingwave_connector::sink::catalog::{
    SinkEncode, SinkFormat, SinkFormatDesc, SinkId, SinkType,
};
use risingwave_connector::sink::log_store::{
    LogReader, LogStoreReadItem, LogStoreResult, TruncateOffset,
};
use risingwave_connector::sink::mock_coordination_client::MockMetaClient;
use risingwave_connector::sink::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_UPSERT, Sink, SinkError, SinkMetaClient, SinkParam,
    SinkWriterParam, build_sink,
};
use risingwave_connector::source::datagen::{
    DatagenProperties, DatagenSplitEnumerator, DatagenSplitReader,
};
use risingwave_connector::source::{
    Column, SourceContext, SourceEnumeratorContext, SplitEnumerator, SplitReader,
};
use risingwave_stream::executor::test_utils::prelude::ColumnDesc;
use risingwave_stream::executor::{Barrier, Message, MessageStreamItem, StreamExecutorError};
use serde::{Deserialize, Deserializer};
use thiserror_ext::AsReport;
use tokio::sync::oneshot::Sender;
use tokio::time::sleep;

/// Interval between checkpoint barriers injected by the mock source, in ms.
const CHECKPOINT_INTERVAL: u64 = 1000;
/// Interval at which the cumulative row counter is sampled, in ms.
const THROUGHPUT_METRIC_RECORD_INTERVAL: u64 = 500;
/// Total benchmark duration, in seconds.
const BENCH_TIME: u64 = 20;
/// Sentinel sink name that benchmarks the source path without a real sink.
const BENCH_TEST: &str = "bench_test";

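/// A mock [`LogReader`] that feeds the sink from a `MockDatagenSource`
/// instead of a real log store. Chunks are counted into `throughput_metric`,
/// barriers are surfaced as checkpoint barriers, and a message on `stop_rx`
/// ends the benchmark by sending the collected metric back through
/// `result_tx` and then pending forever.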
pub struct MockRangeLogReader {
    upstreams: BoxStream<'static, MessageStreamItem>,
    current_epoch: u64,
    chunk_id: usize,
    throughput_metric: Option<ThroughputMetric>,
    stop_rx: tokio::sync::mpsc::Receiver<()>,
    result_tx: Option<Sender<ThroughputMetric>>,
}

impl LogReader for MockRangeLogReader {
    async fn init(&mut self) -> LogStoreResult<()> {
        // Record a zero baseline so the first sampled delta is meaningful.
        self.throughput_metric.as_mut().unwrap().add_metric(0);
        Ok(())
    }

    async fn next_item(&mut self) -> LogStoreResult<(u64, LogStoreReadItem)> {
        tokio::select! {
            _ = self.stop_rx.recv() => {
                self.result_tx
                    .take()
                    .unwrap()
                    .send(self.throughput_metric.take().unwrap())
                    .map_err(|_| anyhow!("failed to send throughput_metric"))?;
                // The benchmark is over; park this reader forever.
                futures::future::pending().await
            },
            item = self.upstreams.next() => {
                match item.unwrap().unwrap() {
                    Message::Barrier(barrier) => {
                        let prev_epoch = self.current_epoch;
                        self.current_epoch = barrier.epoch.curr;
                        Ok((
                            prev_epoch,
                            LogStoreReadItem::Barrier {
                                is_checkpoint: true,
                                new_vnode_bitmap: None,
                                is_stop: false,
                                add_columns: None,
                            },
                        ))
                    }
                    Message::Chunk(chunk) => {
                        self.throughput_metric.as_mut().unwrap().add_metric(chunk.capacity());
                        self.chunk_id += 1;
                        Ok((
                            self.current_epoch,
                            LogStoreReadItem::StreamChunk {
                                chunk,
                                chunk_id: self.chunk_id,
                            },
                        ))
                    }
                    _ => Err(anyhow!("unexpected message type")),
                }
            }
        }
    }

    fn truncate(&mut self, _offset: TruncateOffset) -> LogStoreResult<()> {
        Ok(())
    }

    async fn rewind(&mut self) -> LogStoreResult<()> {
        Err(anyhow!("should not call rewind"))
    }

    async fn start_from(&mut self, _start_offset: Option<u64>) -> LogStoreResult<()> {
        Ok(())
    }
}

impl MockRangeLogReader {
    fn new(
        mock_source: MockDatagenSource,
        throughput_metric: ThroughputMetric,
        stop_rx: tokio::sync::mpsc::Receiver<()>,
        result_tx: Sender<ThroughputMetric>,
    ) -> MockRangeLogReader {
        MockRangeLogReader {
            upstreams: mock_source.into_stream().boxed(),
            current_epoch: 0,
            chunk_id: 0,
            throughput_metric: Some(throughput_metric),
            stop_rx,
            result_tx: Some(result_tx),
        }
    }
}

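/// Collects sink throughput. A background task samples the shared cumulative
/// row counter every `THROUGHPUT_METRIC_RECORD_INTERVAL` ms; on stop, the
/// samples are handed back and adjacent ones are diffed into rows/s. For
/// example, cumulative samples [0, 500, 1200] taken 500 ms apart yield rates
/// [(500 - 0) * 1000 / 500, (1200 - 500) * 1000 / 500] = [1000, 1400] rows/s.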
struct ThroughputMetric {
    accumulate_chunk_size: Arc<AtomicU64>,
    stop_tx: oneshot::Sender<()>,
    vec_rx: oneshot::Receiver<Vec<u64>>,
}

impl ThroughputMetric {
    pub fn new() -> Self {
        let (stop_tx, mut stop_rx) = oneshot::channel::<()>();
        let (vec_tx, vec_rx) = oneshot::channel::<Vec<u64>>();
        let accumulate_chunk_size = Arc::new(AtomicU64::new(0));
        let accumulate_chunk_size_clone = accumulate_chunk_size.clone();
        // Background sampler: records the cumulative row count every
        // `THROUGHPUT_METRIC_RECORD_INTERVAL` ms until `stop_tx` fires.
        tokio::spawn(async move {
            let mut chunk_size_list = vec![];
            loop {
                tokio::select! {
                    _ = sleep(tokio::time::Duration::from_millis(
                        THROUGHPUT_METRIC_RECORD_INTERVAL,
                    )) => {
                        chunk_size_list.push(accumulate_chunk_size_clone.load(Ordering::Relaxed));
                    }
                    _ = &mut stop_rx => {
                        vec_tx.send(chunk_size_list).unwrap();
                        break;
                    }
                }
            }
        });

        Self {
            accumulate_chunk_size,
            stop_tx,
            vec_rx,
        }
    }

    pub fn add_metric(&mut self, chunk_size: usize) {
        self.accumulate_chunk_size
            .fetch_add(chunk_size as u64, Ordering::Relaxed);
    }

    pub async fn print_throughput(self) {
        self.stop_tx.send(()).unwrap();
        let throughput_sum_vec = self.vec_rx.await.unwrap();
        // Diff adjacent cumulative samples and scale to rows per second.
        #[allow(clippy::disallowed_methods)]
        let throughput_vec = throughput_sum_vec
            .iter()
            .zip(throughput_sum_vec.iter().skip(1))
            .map(|(current, next)| (next - current) * 1000 / THROUGHPUT_METRIC_RECORD_INTERVAL)
            .collect_vec();
        if throughput_vec.is_empty() {
            println!("Throughput Sink: no throughput samples were collected, please check the setup");
            return;
        }
        let avg = throughput_vec.iter().sum::<u64>() / throughput_vec.len() as u64;
        let throughput_vec_sorted = throughput_vec.iter().sorted().collect_vec();
        let p90 = throughput_vec_sorted[throughput_vec_sorted.len() * 90 / 100];
        let p95 = throughput_vec_sorted[throughput_vec_sorted.len() * 95 / 100];
        let p99 = throughput_vec_sorted[throughput_vec_sorted.len() * 99 / 100];
        println!("Throughput Sink:");
        println!("avg: {} rows/s", avg);
        println!("p90: {} rows/s", p90);
        println!("p95: {} rows/s", p95);
        println!("p99: {} rows/s", p99);
        let draw_vec: Vec<(f32, f32)> = throughput_vec
            .iter()
            .enumerate()
            .map(|(index, &value)| {
                (
                    (index as f32) * (THROUGHPUT_METRIC_RECORD_INTERVAL as f32 / 1000_f32),
                    value as f32,
                )
            })
            .collect();

        // Render the per-interval throughput as a line + point chart.
        let root = SVGBackend::new("throughput.svg", (640, 480)).into_drawing_area();
        root.fill(&WHITE).unwrap();
        let root = root.margin(10, 10, 10, 10);
        let mut chart = ChartBuilder::on(&root)
            .caption("Throughput Sink", ("sans-serif", 40).into_font())
            .x_label_area_size(20)
            .y_label_area_size(40)
            .build_cartesian_2d(
                0.0..BENCH_TIME as f32,
                **throughput_vec_sorted.first().unwrap() as f32
                    ..**throughput_vec_sorted.last().unwrap() as f32,
            )
            .unwrap();

        chart
            .configure_mesh()
            .x_labels(5)
            .y_labels(5)
            .y_label_formatter(&|x| format!("{:.0}", x))
            .draw()
            .unwrap();

        chart
            .draw_series(LineSeries::new(draw_vec.clone(), &RED))
            .unwrap();
        chart
            .draw_series(PointSeries::of_element(draw_vec, 5, &RED, &|c, s, st| {
                EmptyElement::at(c) + Circle::new((0, 0), s, st.filled())
            }))
            .unwrap();
        root.present().unwrap();

        println!(
            "Throughput Sink: {:?}",
            throughput_vec
                .iter()
                .map(|a| format!("{} rows/s", a))
                .collect_vec()
        );
    }
}

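/// The benchmark's data generator: one `DatagenSplitReader` per split,
/// round-robined into a single message stream, with a checkpoint barrier
/// injected every `CHECKPOINT_INTERVAL` ms.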
pub struct MockDatagenSource {
    datagen_split_readers: Vec<DatagenSplitReader>,
}

impl MockDatagenSource {
    pub async fn new(
        rows_per_second: u64,
        source_schema: Vec<Column>,
        split_num: String,
    ) -> MockDatagenSource {
        let properties = DatagenProperties {
            split_num: Some(split_num),
            rows_per_second,
            fields: HashMap::default(),
        };
        let mut datagen_enumerator = DatagenSplitEnumerator::new(
            properties.clone(),
            SourceEnumeratorContext::dummy().into(),
        )
        .await
        .unwrap();
        let parser_config = ParserConfig {
            specific: SpecificParserConfig {
                encoding_config: EncodingProperties::Native,
                protocol_config: ProtocolProperties::Native,
            },
            ..Default::default()
        };
        // One reader per split, each producing an independent slice of rows.
        let mut datagen_split_readers = vec![];
        let mut datagen_splits = datagen_enumerator.list_splits().await.unwrap();
        while let Some(split) = datagen_splits.pop() {
            datagen_split_readers.push(
                DatagenSplitReader::new(
                    properties.clone(),
                    vec![split],
                    parser_config.clone(),
                    SourceContext::dummy().into(),
                    Some(source_schema.clone()),
                )
                .await
                .unwrap(),
            );
        }
        MockDatagenSource {
            datagen_split_readers,
        }
    }

    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub async fn source_to_data_stream(mut self) {
        let mut readers = vec![];
        while let Some(reader) = self.datagen_split_readers.pop() {
            readers.push(reader.into_stream());
        }
        // Round-robin over the split readers so every split contributes evenly.
        loop {
            for i in &mut readers {
                let item = i.next().await.unwrap().unwrap();
                yield Message::Chunk(item);
            }
        }
    }

    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub async fn into_stream(self) {
        // Prefer the barrier stream on ties so checkpoints are never starved.
        let stream = select_with_strategy(
            Self::barrier_to_message_stream().map_ok(Either::Left),
            self.source_to_data_stream().map_ok(Either::Right),
            |_: &mut PollNext| PollNext::Left,
        );
        #[for_await]
        for message in stream {
            match message.unwrap() {
                Either::Left(Message::Barrier(barrier)) => {
                    yield Message::Barrier(barrier);
                }
                Either::Right(Message::Chunk(chunk)) => yield Message::Chunk(chunk),
                _ => {
                    return Err(StreamExecutorError::from(
                        "unexpected message type".to_owned(),
                    ));
                }
            }
        }
    }

    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub async fn barrier_to_message_stream() {
        let mut epoch = 0_u64;
        loop {
            let prev_epoch = epoch;
            epoch += 1;
            let barrier = Barrier::with_prev_epoch_for_test(epoch, prev_epoch);
            yield Message::Barrier(barrier);
            sleep(tokio::time::Duration::from_millis(CHECKPOINT_INTERVAL)).await;
        }
    }
}

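/// Drives one sink against the mock log reader. If the sink is coordinated
/// (its `new_coordinator` succeeds), a `MockMetaClient` stands in for the
/// meta service. The log stream never terminates, so this only returns on
/// error.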
async fn consume_log_stream<S: Sink>(
    sink: S,
    mut log_reader: MockRangeLogReader,
    mut sink_writer_param: SinkWriterParam,
) -> Result<(), String> {
    if let Ok(coordinator) = sink.new_coordinator(None).await {
        sink_writer_param.meta_client = Some(SinkMetaClient::MockMetaClient(MockMetaClient::new(
            coordinator,
        )));
        sink_writer_param.vnode_bitmap = Some(Bitmap::ones(1));
    }
    let log_sinker = sink.new_log_sinker(sink_writer_param).await.unwrap();
    match log_sinker.consume_log_and_sink(&mut log_reader).await {
        Ok(_) => Err("Stream closed".to_owned()),
        Err(e) => Err(e.to_report_string()),
    }
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct TableSchemaFromYml {
    table_name: String,
    pk_indices: Option<Vec<usize>>,
    columns: Vec<ColumnDescFromYml>,
}

impl TableSchemaFromYml {
    pub fn get_source_schema(&self) -> Vec<Column> {
        self.columns
            .iter()
            .map(|column| Column {
                name: column.name.clone(),
                data_type: column.r#type.clone(),
                is_visible: true,
            })
            .collect()
    }

    pub fn get_sink_schema(&self) -> Vec<ColumnDesc> {
        self.columns
            .iter()
            .map(|column| {
                ColumnDesc::named(column.name.clone(), ColumnId::new(1), column.r#type.clone())
            })
            .collect()
    }
}

#[derive(Debug, Deserialize)]
struct ColumnDescFromYml {
    name: String,
    #[serde(deserialize_with = "deserialize_datatype")]
    r#type: DataType,
}

fn deserialize_datatype<'de, D>(deserializer: D) -> Result<DataType, D::Error>
where
    D: Deserializer<'de>,
{
    let s: &str = Deserialize::deserialize(deserializer)?;
    DataType::from_str(s).map_err(serde::de::Error::custom)
}

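// An illustrative `schema.yml` for `read_table_schema_from_yml` below. Field
// names follow `TableSchemaFromYml`; the type strings are whatever
// `DataType::from_str` accepts, so the exact spellings here are assumptions:
//
//   table_name: t_sink_bench
//   pk_indices: [0]
//   columns:
//     - name: id
//       type: bigint
//     - name: v1
//       type: varchar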
fn read_table_schema_from_yml(path: &str) -> TableSchemaFromYml {
    let data = std::fs::read_to_string(path).unwrap();
    let table: TableSchemaFromYml = serde_yaml::from_str(&data).unwrap();
    table
}

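// An illustrative `sink_option.yml` for `read_sink_option_from_yml` below: a
// map from a sink name (matched against `--sink`) to its properties. Only
// `connector` and `type` are read in `main`; the remaining keys are passed
// through to the connector, so the Kafka-style keys here are assumptions:
//
//   kafka:
//     connector: kafka
//     type: append-only
//     properties.bootstrap.server: 127.0.0.1:9092
//     topic: sink_bench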
fn read_sink_option_from_yml(path: &str) -> HashMap<String, BTreeMap<String, String>> {
    let data = std::fs::read_to_string(path).unwrap();
    let sink_option: HashMap<String, BTreeMap<String, String>> =
        serde_yaml::from_str(&data).unwrap();
    sink_option
}

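/// Command-line options. `--sink` defaults to `bench_test`, which exercises
/// only the mock source and log reader without a real sink; any other value
/// is looked up in the sink option YAML. An illustrative invocation (the
/// binary name is assumed):
///
/// ```text
/// sink_bench --sink kafka --rows-per-second 100000 --split-num 10
/// ```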
#[derive(Parser, Debug)]
pub struct Config {
    #[clap(long, default_value = "./sink_bench/schema.yml")]
    schema_path: String,

    #[clap(short, long, default_value = "./sink_bench/sink_option.yml")]
    option_path: String,

    #[clap(short, long, default_value = BENCH_TEST)]
    sink: String,

    #[clap(short, long)]
    rows_per_second: u64,

    #[clap(long, default_value = "10")]
    split_num: String,
}

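/// Maps the legacy `type` property to a `SinkFormatDesc`. Redis is handled
/// explicitly with JSON encoding; every other connector defers to
/// `SinkFormatDesc::from_legacy_type`.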
fn mock_from_legacy_type(
    connector: &str,
    r#type: &str,
) -> Result<Option<SinkFormatDesc>, SinkError> {
    use risingwave_connector::sink::Sink as _;
    use risingwave_connector::sink::redis::RedisSink;
    if connector.eq(RedisSink::SINK_NAME) {
        let format = match r#type {
            SINK_TYPE_APPEND_ONLY => SinkFormat::AppendOnly,
            SINK_TYPE_UPSERT => SinkFormat::Upsert,
            _ => {
                return Err(SinkError::Config(anyhow!(
                    "sink type unsupported: {}",
                    r#type
                )));
            }
        };
        Ok(Some(SinkFormatDesc {
            format,
            encode: SinkEncode::Json,
            options: Default::default(),
            secret_refs: Default::default(),
            key_encode: None,
            connection_id: None,
        }))
    } else {
        SinkFormatDesc::from_legacy_type(connector, r#type)
    }
}

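/// Entry point: wires `MockDatagenSource -> MockRangeLogReader -> sink` (or
/// no sink for `bench_test`), lets the pipeline run for `BENCH_TIME` seconds,
/// then signals the reader to stop and prints the collected throughput.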
#[tokio::main]
async fn main() {
    let cfg = Config::parse();
    let table_schema = read_table_schema_from_yml(&cfg.schema_path);
    let mock_datagen_source = MockDatagenSource::new(
        cfg.rows_per_second,
        table_schema.get_source_schema(),
        cfg.split_num,
    )
    .await;
    let (data_size_tx, data_size_rx) = tokio::sync::oneshot::channel::<ThroughputMetric>();
    let (stop_tx, stop_rx) = tokio::sync::mpsc::channel::<()>(5);
    let throughput_metric = ThroughputMetric::new();

    let mut mock_range_log_reader = MockRangeLogReader::new(
        mock_datagen_source,
        throughput_metric,
        stop_rx,
        data_size_tx,
    );
    if cfg.sink == BENCH_TEST {
        println!("Starting sink bench, waiting {}s", BENCH_TIME);
        // No real sink: just drain the log reader to measure the source path.
        tokio::spawn(async move {
            mock_range_log_reader.init().await.unwrap();
            loop {
                mock_range_log_reader.next_item().await.unwrap();
            }
        });
    } else {
        let properties = read_sink_option_from_yml(&cfg.option_path)
            .get(&cfg.sink)
            .expect("sink type not found in the option file")
            .clone();

        let connector = properties.get("connector").unwrap().clone();
        let format_desc = mock_from_legacy_type(
            &connector,
            properties.get("type").unwrap_or(&"append-only".to_owned()),
        )
        .unwrap();
        let sink_param = SinkParam {
            sink_id: SinkId::new(1),
            sink_name: cfg.sink.clone(),
            properties,
            columns: table_schema.get_sink_schema(),
            downstream_pk: table_schema.pk_indices,
            sink_type: SinkType::AppendOnly,
            format_desc,
            db_name: "not_need_set".to_owned(),
            sink_from_name: "not_need_set".to_owned(),
        };
        let sink = build_sink(sink_param).unwrap();
        let sink_writer_param = SinkWriterParam::for_test();
        println!("Starting sink bench, waiting {}s", BENCH_TIME);
        tokio::spawn(async move {
            dispatch_sink!(sink, sink, {
                consume_log_stream(*sink, mock_range_log_reader, sink_writer_param).boxed()
            })
            .await
            .unwrap();
        });
    }
    sleep(tokio::time::Duration::from_secs(BENCH_TIME)).await;
    println!("Bench over!");
    stop_tx.send(()).await.unwrap();
    data_size_rx.await.unwrap().print_throughput().await;
}
567}