sink_bench/
main.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14#![feature(coroutines)]
15#![feature(proc_macro_hygiene)]
16#![feature(stmt_expr_attributes)]
17#![feature(let_chains)]
18#![recursion_limit = "256"]
19
20use core::str::FromStr;
21use core::sync::atomic::Ordering;
22use std::collections::{BTreeMap, HashMap};
23use std::sync::Arc;
24use std::sync::atomic::AtomicU64;
25
26use anyhow::anyhow;
27use clap::Parser;
28use futures::channel::oneshot;
29use futures::prelude::future::Either;
30use futures::prelude::stream::{BoxStream, PollNext};
31use futures::stream::select_with_strategy;
32use futures::{FutureExt, StreamExt, TryStreamExt};
33use futures_async_stream::try_stream;
34use itertools::Itertools;
35use plotters::backend::SVGBackend;
36use plotters::chart::ChartBuilder;
37use plotters::drawing::IntoDrawingArea;
38use plotters::element::{Circle, EmptyElement};
39use plotters::series::{LineSeries, PointSeries};
40use plotters::style::{IntoFont, RED, WHITE};
41use risingwave_common::bitmap::Bitmap;
42use risingwave_common::catalog::ColumnId;
43use risingwave_common::types::DataType;
44use risingwave_connector::dispatch_sink;
45use risingwave_connector::parser::{
46    EncodingProperties, ParserConfig, ProtocolProperties, SpecificParserConfig,
47};
48use risingwave_connector::sink::catalog::{
49    SinkEncode, SinkFormat, SinkFormatDesc, SinkId, SinkType,
50};
51use risingwave_connector::sink::log_store::{
52    LogReader, LogStoreReadItem, LogStoreResult, TruncateOffset,
53};
54use risingwave_connector::sink::mock_coordination_client::MockMetaClient;
55use risingwave_connector::sink::{
56    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_UPSERT, Sink, SinkError, SinkMetaClient, SinkParam,
57    SinkWriterParam, build_sink,
58};
59use risingwave_connector::source::datagen::{
60    DatagenProperties, DatagenSplitEnumerator, DatagenSplitReader,
61};
62use risingwave_connector::source::{
63    Column, SourceContext, SourceEnumeratorContext, SplitEnumerator, SplitReader,
64};
65use risingwave_stream::executor::test_utils::prelude::ColumnDesc;
66use risingwave_stream::executor::{Barrier, Message, MessageStreamItem, StreamExecutorError};
67use sea_orm::DatabaseConnection;
68use serde::{Deserialize, Deserializer};
69use thiserror_ext::AsReport;
70use tokio::sync::oneshot::Sender;
71use tokio::time::sleep;
72
/// Interval (ms) between barriers emitted by the mock source; the mock log
/// reader treats every barrier as a checkpoint.
const CHECKPOINT_INTERVAL: u64 = 1000;
/// Interval (ms) at which the background task samples the cumulative row count.
const THROUGHPUT_METRIC_RECORD_INTERVAL: u64 = 500;
/// Total benchmark duration, in seconds.
const BENCH_TIME: u64 = 20;
/// Sentinel sink name: drain the log reader without attaching a real sink.
const BENCH_TEST: &str = "bench_test";
77
/// A `LogReader` backed by an in-memory stream of barriers and chunks from
/// [`MockDatagenSource`], used to feed a sink during benchmarking.
pub struct MockRangeLogReader {
    // Merged barrier/chunk stream produced by the mock source.
    upstreams: BoxStream<'static, MessageStreamItem>,
    // Epoch of the most recent barrier; chunks are attributed to this epoch.
    current_epoch: u64,
    // Monotonically increasing id assigned to each emitted chunk.
    chunk_id: usize,
    // Row-throughput collector; `take`n and sent via `result_tx` when the
    // benchmark stops.
    throughput_metric: Option<ThroughputMetric>,
    // Receiving a unit here signals the benchmark is over.
    stop_rx: tokio::sync::mpsc::Receiver<()>,
    // One-shot channel used to hand the final metrics back to `main`.
    result_tx: Option<Sender<ThroughputMetric>>,
}
86
impl LogReader for MockRangeLogReader {
    async fn init(&mut self) -> LogStoreResult<()> {
        // Record a zero-row sample so the first throughput delta has a baseline.
        self.throughput_metric.as_mut().unwrap().add_metric(0);
        Ok(())
    }

    /// Returns the next log item; once the stop signal arrives, hands the
    /// collected metrics back to `main` and then parks forever so the sink
    /// keeps waiting instead of observing a stream end.
    async fn next_item(&mut self) -> LogStoreResult<(u64, LogStoreReadItem)> {
        tokio::select! {
            // Benchmark finished: ship the metrics out, then never resolve.
            _ = self.stop_rx.recv() => {
                self.result_tx
                        .take()
                        .unwrap()
                        .send(self.throughput_metric.take().unwrap())
                        .map_err(|_| anyhow!("Can't send throughput_metric"))?;
                    futures::future::pending().await
                },
            item = self.upstreams.next() => {
                match item.unwrap().unwrap() {
                    Message::Barrier(barrier) => {
                        // Emit the barrier under the previous epoch, then
                        // advance; every barrier is reported as a checkpoint.
                        let prev_epoch = self.current_epoch;
                        self.current_epoch = barrier.epoch.curr;
                        Ok((
                            prev_epoch,
                            LogStoreReadItem::Barrier {
                                is_checkpoint: true,
                                new_vnode_bitmap: None,
                                is_stop: false,
                            },
                        ))
                    }
                    Message::Chunk(chunk) => {
                        // NOTE(review): `capacity()` counts all chunk slots;
                        // if chunks ever carry invisible rows this overcounts
                        // relative to `cardinality()` — confirm datagen chunks
                        // are fully visible.
                        self.throughput_metric.as_mut().unwrap().add_metric(chunk.capacity());
                        self.chunk_id += 1;
                        Ok((
                            self.current_epoch,
                            LogStoreReadItem::StreamChunk {
                                chunk,
                                chunk_id: self.chunk_id,
                            },
                        ))
                    }
                    _ => Err(anyhow!("Can't assert message type".to_owned())),
                }
            }
        }
    }

    /// No-op: the in-memory stream has nothing to truncate.
    fn truncate(&mut self, _offset: TruncateOffset) -> LogStoreResult<()> {
        Ok(())
    }

    /// Rewinding is unsupported for this mock reader.
    async fn rewind(&mut self) -> LogStoreResult<()> {
        Err(anyhow!("should not call rewind"))
    }

    async fn start_from(&mut self, _start_offset: Option<u64>) -> LogStoreResult<()> {
        Ok(())
    }
}
146
147impl MockRangeLogReader {
148    fn new(
149        mock_source: MockDatagenSource,
150        throughput_metric: ThroughputMetric,
151        stop_rx: tokio::sync::mpsc::Receiver<()>,
152        result_tx: Sender<ThroughputMetric>,
153    ) -> MockRangeLogReader {
154        MockRangeLogReader {
155            upstreams: mock_source.into_stream().boxed(),
156            current_epoch: 0,
157            chunk_id: 0,
158            throughput_metric: Some(throughput_metric),
159            stop_rx,
160            result_tx: Some(result_tx),
161        }
162    }
163}
164
/// Periodically samples the total number of rows sunk (in a background task)
/// so average/percentile throughput can be computed when the benchmark ends.
struct ThroughputMetric {
    // Running total of rows seen; bumped by `add_metric`, sampled periodically.
    accumulate_chunk_size: Arc<AtomicU64>,
    // Signals the sampling task to stop and flush its samples.
    stop_tx: oneshot::Sender<()>,
    // Receives the list of cumulative samples from the sampling task.
    vec_rx: oneshot::Receiver<Vec<u64>>,
}
170
impl ThroughputMetric {
    /// Spawns a background sampler task and returns the metric handle.
    ///
    /// Every `THROUGHPUT_METRIC_RECORD_INTERVAL` ms the task snapshots the
    /// cumulative row counter into a list; when `stop_tx` fires, the whole
    /// list is sent back over the `vec_tx`/`vec_rx` oneshot channel.
    pub fn new() -> Self {
        let (stop_tx, mut stop_rx) = oneshot::channel::<()>();
        let (vec_tx, vec_rx) = oneshot::channel::<Vec<u64>>();
        let accumulate_chunk_size = Arc::new(AtomicU64::new(0));
        let accumulate_chunk_size_clone = accumulate_chunk_size.clone();
        tokio::spawn(async move {
            let mut chunk_size_list = vec![];
            loop {
                tokio::select! {
                    _ = sleep(tokio::time::Duration::from_millis(
                        THROUGHPUT_METRIC_RECORD_INTERVAL,
                    )) => {
                        // Record the cumulative count; deltas are computed in
                        // `print_throughput`.
                        chunk_size_list.push(accumulate_chunk_size_clone.load(Ordering::Relaxed));
                    }
                    _ = &mut stop_rx => {
                        vec_tx.send(chunk_size_list).unwrap();
                        break;
                    }
                }
            }
        });

        Self {
            accumulate_chunk_size,
            stop_tx,
            vec_rx,
        }
    }

    /// Adds `chunk_size` rows to the cumulative counter.
    pub fn add_metric(&mut self, chunk_size: usize) {
        self.accumulate_chunk_size
            .fetch_add(chunk_size as u64, Ordering::Relaxed);
    }

    /// Stops the sampler, prints avg/p90/p95/p99 throughput (rows/s), and
    /// renders the per-interval throughput curve to `throughput.svg`.
    pub async fn print_throughput(self) {
        self.stop_tx.send(()).unwrap();
        let throughput_sum_vec = self.vec_rx.await.unwrap();
        // Convert cumulative samples into per-interval rates (rows/s) by
        // differencing consecutive samples and scaling by the sample interval.
        #[allow(clippy::disallowed_methods)]
        let throughput_vec = throughput_sum_vec
            .iter()
            .zip(throughput_sum_vec.iter().skip(1))
            .map(|(current, next)| (next - current) * 1000 / THROUGHPUT_METRIC_RECORD_INTERVAL)
            .collect_vec();
        if throughput_vec.is_empty() {
            println!("Throughput Sink: Don't get Throughput, please check");
            return;
        }
        let avg = throughput_vec.iter().sum::<u64>() / throughput_vec.len() as u64;
        // Nearest-index percentiles over the sorted per-interval rates.
        let throughput_vec_sorted = throughput_vec.iter().sorted().collect_vec();
        let p90 = throughput_vec_sorted[throughput_vec_sorted.len() * 90 / 100];
        let p95 = throughput_vec_sorted[throughput_vec_sorted.len() * 95 / 100];
        let p99 = throughput_vec_sorted[throughput_vec_sorted.len() * 99 / 100];
        println!("Throughput Sink:");
        println!("avg: {:?} rows/s ", avg);
        println!("p90: {:?} rows/s ", p90);
        println!("p95: {:?} rows/s ", p95);
        println!("p99: {:?} rows/s ", p99);
        // Chart points: x = elapsed seconds, y = rows/s for that interval.
        let draw_vec: Vec<(f32, f32)> = throughput_vec
            .iter()
            .enumerate()
            .map(|(index, &value)| {
                (
                    (index as f32) * (THROUGHPUT_METRIC_RECORD_INTERVAL as f32 / 1000_f32),
                    value as f32,
                )
            })
            .collect();

        // Render an SVG line + point chart of the throughput over time.
        // NOTE(review): if all samples are equal the y-range below is empty —
        // confirm plotters tolerates a zero-height range.
        let root = SVGBackend::new("throughput.svg", (640, 480)).into_drawing_area();
        root.fill(&WHITE).unwrap();
        let root = root.margin(10, 10, 10, 10);
        let mut chart = ChartBuilder::on(&root)
            .caption("Throughput Sink", ("sans-serif", 40).into_font())
            .x_label_area_size(20)
            .y_label_area_size(40)
            .build_cartesian_2d(
                0.0..BENCH_TIME as f32,
                **throughput_vec_sorted.first().unwrap() as f32
                    ..**throughput_vec_sorted.last().unwrap() as f32,
            )
            .unwrap();

        chart
            .configure_mesh()
            .x_labels(5)
            .y_labels(5)
            .y_label_formatter(&|x| format!("{:.0}", x))
            .draw()
            .unwrap();

        chart
            .draw_series(LineSeries::new(draw_vec.clone(), &RED))
            .unwrap();
        chart
            .draw_series(PointSeries::of_element(draw_vec, 5, &RED, &|c, s, st| {
                EmptyElement::at(c) + Circle::new((0, 0), s, st.filled())
            }))
            .unwrap();
        root.present().unwrap();

        println!(
            "Throughput Sink: {:?}",
            throughput_vec
                .iter()
                .map(|a| format!("{} rows/s", a))
                .collect_vec()
        );
    }
}
281
/// A mock source wrapping datagen split readers; emits an endless stream of
/// data chunks interleaved with periodic barriers.
pub struct MockDatagenSource {
    // One reader per datagen split; polled round-robin into a single stream.
    datagen_split_readers: Vec<DatagenSplitReader>,
}
impl MockDatagenSource {
    /// Enumerates datagen splits for the given properties and builds one
    /// split reader per split, all generating rows for `source_schema`.
    /// `split_num` is kept as a string because `DatagenProperties` expects it
    /// in that form.
    pub async fn new(
        rows_per_second: u64,
        source_schema: Vec<Column>,
        split_num: String,
    ) -> MockDatagenSource {
        let properties = DatagenProperties {
            split_num: Some(split_num),
            rows_per_second,
            fields: HashMap::default(),
        };
        let mut datagen_enumerator = DatagenSplitEnumerator::new(
            properties.clone(),
            SourceEnumeratorContext::dummy().into(),
        )
        .await
        .unwrap();
        // Native encoding/protocol: data flows through without re-parsing.
        let parser_config = ParserConfig {
            specific: SpecificParserConfig {
                encoding_config: EncodingProperties::Native,
                protocol_config: ProtocolProperties::Native,
            },
            ..Default::default()
        };
        let mut datagen_split_readers = vec![];
        let mut datagen_splits = datagen_enumerator.list_splits().await.unwrap();
        // One reader per split, each handed a single split to generate from.
        while let Some(splits) = datagen_splits.pop() {
            datagen_split_readers.push(
                DatagenSplitReader::new(
                    properties.clone(),
                    vec![splits],
                    parser_config.clone(),
                    SourceContext::dummy().into(),
                    Some(source_schema.clone()),
                )
                .await
                .unwrap(),
            );
        }
        MockDatagenSource {
            datagen_split_readers,
        }
    }

    /// Endless stream of data chunks, polling every split reader round-robin.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub async fn source_to_data_stream(mut self) {
        let mut readers = vec![];
        while let Some(reader) = self.datagen_split_readers.pop() {
            readers.push(reader.into_stream());
        }
        loop {
            // NOTE(review): the unwraps assume datagen streams never end or
            // error — confirm; a terminated reader would panic here.
            for i in &mut readers {
                let item = i.next().await.unwrap().unwrap();
                yield Message::Chunk(item);
            }
        }
    }

    /// Merges the barrier stream with the data stream, preferring barriers
    /// (the left side) whenever both are ready.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub async fn into_stream(self) {
        let stream = select_with_strategy(
            Self::barrier_to_message_stream().map_ok(Either::Left),
            self.source_to_data_stream().map_ok(Either::Right),
            |_: &mut PollNext| PollNext::Left,
        );
        #[for_await]
        for message in stream {
            match message.unwrap() {
                Either::Left(Message::Barrier(barrier)) => {
                    yield Message::Barrier(barrier);
                }
                Either::Right(Message::Chunk(chunk)) => yield Message::Chunk(chunk),
                _ => {
                    return Err(StreamExecutorError::from(
                        "Can't assert message type".to_owned(),
                    ));
                }
            }
        }
    }

    /// Emits a barrier every `CHECKPOINT_INTERVAL` ms with strictly
    /// increasing epochs (each barrier carries its predecessor's epoch).
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub async fn barrier_to_message_stream() {
        let mut epoch = 0_u64;
        loop {
            let prev_epoch = epoch;
            epoch += 1;
            let barrier = Barrier::with_prev_epoch_for_test(epoch, prev_epoch);
            yield Message::Barrier(barrier);
            sleep(tokio::time::Duration::from_millis(CHECKPOINT_INTERVAL)).await;
        }
    }
}
378
379async fn consume_log_stream<S: Sink>(
380    sink: S,
381    mut log_reader: MockRangeLogReader,
382    mut sink_writer_param: SinkWriterParam,
383) -> Result<(), String>
384where
385    <S as risingwave_connector::sink::Sink>::Coordinator: std::marker::Send,
386    <S as risingwave_connector::sink::Sink>::Coordinator: 'static,
387{
388    if let Ok(coordinator) = sink.new_coordinator(DatabaseConnection::Disconnected).await {
389        sink_writer_param.meta_client = Some(SinkMetaClient::MockMetaClient(MockMetaClient::new(
390            Box::new(coordinator),
391        )));
392        sink_writer_param.vnode_bitmap = Some(Bitmap::ones(1));
393    }
394    let log_sinker = sink.new_log_sinker(sink_writer_param).await.unwrap();
395    match log_sinker.consume_log_and_sink(&mut log_reader).await {
396        Ok(_) => Err("Stream closed".to_owned()),
397        Err(e) => Err(e.to_report_string()),
398    }
399}
400
/// In-memory form of the schema YAML: table name, primary-key column
/// indices, and the column list.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct TableSchemaFromYml {
    table_name: String,
    // Indices into `columns` that form the downstream primary key.
    pk_indices: Vec<usize>,
    columns: Vec<ColumnDescFromYml>,
}
408
409impl TableSchemaFromYml {
410    pub fn get_source_schema(&self) -> Vec<Column> {
411        self.columns
412            .iter()
413            .map(|column| Column {
414                name: column.name.clone(),
415                data_type: column.r#type.clone(),
416                is_visible: true,
417            })
418            .collect()
419    }
420
421    pub fn get_sink_schema(&self) -> Vec<ColumnDesc> {
422        self.columns
423            .iter()
424            .map(|column| {
425                ColumnDesc::named(column.name.clone(), ColumnId::new(1), column.r#type.clone())
426            })
427            .collect()
428    }
429}
/// One column entry in the schema YAML: a name plus a `DataType` parsed from
/// its textual form.
#[derive(Debug, Deserialize)]
struct ColumnDescFromYml {
    name: String,
    // The YAML stores the type as a string; parse it via `deserialize_datatype`.
    #[serde(deserialize_with = "deserialize_datatype")]
    r#type: DataType,
}
436
437fn deserialize_datatype<'de, D>(deserializer: D) -> Result<DataType, D::Error>
438where
439    D: Deserializer<'de>,
440{
441    let s: &str = Deserialize::deserialize(deserializer)?;
442    DataType::from_str(s).map_err(serde::de::Error::custom)
443}
444
445fn read_table_schema_from_yml(path: &str) -> TableSchemaFromYml {
446    let data = std::fs::read_to_string(path).unwrap();
447    let table: TableSchemaFromYml = serde_yaml::from_str(&data).unwrap();
448    table
449}
450
451fn read_sink_option_from_yml(path: &str) -> HashMap<String, BTreeMap<String, String>> {
452    let data = std::fs::read_to_string(path).unwrap();
453    let sink_option: HashMap<String, BTreeMap<String, String>> =
454        serde_yaml::from_str(&data).unwrap();
455    sink_option
456}
457
// CLI configuration for the sink benchmark.
//
// Fields use `//` comments on purpose: `clap` turns `///` doc comments into
// generated help text, which would change the binary's CLI output.
#[derive(Parser, Debug)]
pub struct Config {
    // Path to the YAML file describing the table schema to generate.
    #[clap(long, default_value = "./sink_bench/schema.yml")]
    schema_path: String,

    // Path to the YAML file holding per-sink connector options.
    #[clap(short, long, default_value = "./sink_bench/sink_option.yml")]
    option_path: String,

    // Which sink to benchmark; the default runs without any real sink.
    #[clap(short, long, default_value = BENCH_TEST)]
    sink: String,

    // Rows per second produced by the datagen source.
    #[clap(short, long)]
    rows_per_second: u64,

    // Number of datagen splits (kept as a string for `DatagenProperties`).
    #[clap(long, default_value = "10")]
    split_num: String,
}
475
476fn mock_from_legacy_type(
477    connector: &str,
478    r#type: &str,
479) -> Result<Option<SinkFormatDesc>, SinkError> {
480    use risingwave_connector::sink::Sink as _;
481    use risingwave_connector::sink::redis::RedisSink;
482    if connector.eq(RedisSink::SINK_NAME) {
483        let format = match r#type {
484            SINK_TYPE_APPEND_ONLY => SinkFormat::AppendOnly,
485            SINK_TYPE_UPSERT => SinkFormat::Upsert,
486            _ => {
487                return Err(SinkError::Config(anyhow!(
488                    "sink type unsupported: {}",
489                    r#type
490                )));
491            }
492        };
493        Ok(Some(SinkFormatDesc {
494            format,
495            encode: SinkEncode::Json,
496            options: Default::default(),
497            secret_refs: Default::default(),
498            key_encode: None,
499            connection_id: None,
500        }))
501    } else {
502        SinkFormatDesc::from_legacy_type(connector, r#type)
503    }
504}
505
506#[tokio::main]
507async fn main() {
508    let cfg = Config::parse();
509    let table_schema = read_table_schema_from_yml(&cfg.schema_path);
510    let mock_datagen_source = MockDatagenSource::new(
511        cfg.rows_per_second,
512        table_schema.get_source_schema(),
513        cfg.split_num,
514    )
515    .await;
516    let (data_size_tx, data_size_rx) = tokio::sync::oneshot::channel::<ThroughputMetric>();
517    let (stop_tx, stop_rx) = tokio::sync::mpsc::channel::<()>(5);
518    let throughput_metric = ThroughputMetric::new();
519
520    let mut mock_range_log_reader = MockRangeLogReader::new(
521        mock_datagen_source,
522        throughput_metric,
523        stop_rx,
524        data_size_tx,
525    );
526    if cfg.sink.eq(&BENCH_TEST.to_owned()) {
527        println!("Start Sink Bench!, Wait {:?}s", BENCH_TIME);
528        tokio::spawn(async move {
529            mock_range_log_reader.init().await.unwrap();
530            loop {
531                mock_range_log_reader.next_item().await.unwrap();
532            }
533        });
534    } else {
535        let properties = read_sink_option_from_yml(&cfg.option_path)
536            .get(&cfg.sink)
537            .expect("Sink type error")
538            .clone();
539
540        let connector = properties.get("connector").unwrap().clone();
541        let format_desc = mock_from_legacy_type(
542            &connector.clone(),
543            properties.get("type").unwrap_or(&"append-only".to_owned()),
544        )
545        .unwrap();
546        let sink_param = SinkParam {
547            sink_id: SinkId::new(1),
548            sink_name: cfg.sink.clone(),
549            properties,
550            columns: table_schema.get_sink_schema(),
551            downstream_pk: table_schema.pk_indices,
552            sink_type: SinkType::AppendOnly,
553            format_desc,
554            db_name: "not_need_set".to_owned(),
555            sink_from_name: "not_need_set".to_owned(),
556        };
557        let sink = build_sink(sink_param).unwrap();
558        let sink_writer_param = SinkWriterParam::for_test();
559        println!("Start Sink Bench!, Wait {:?}s", BENCH_TIME);
560        tokio::spawn(async move {
561            dispatch_sink!(sink, sink, {
562                consume_log_stream(*sink, mock_range_log_reader, sink_writer_param).boxed()
563            })
564            .await
565            .unwrap();
566        });
567    }
568    sleep(tokio::time::Duration::from_secs(BENCH_TIME)).await;
569    println!("Bench Over!");
570    stop_tx.send(()).await.unwrap();
571    data_size_rx.await.unwrap().print_throughput().await;
572}