1#![feature(coroutines)]
15#![feature(proc_macro_hygiene)]
16#![feature(stmt_expr_attributes)]
17#![feature(let_chains)]
18#![recursion_limit = "256"]
19
20use core::str::FromStr;
21use core::sync::atomic::Ordering;
22use std::collections::{BTreeMap, HashMap};
23use std::sync::Arc;
24use std::sync::atomic::AtomicU64;
25
26use anyhow::anyhow;
27use clap::Parser;
28use futures::channel::oneshot;
29use futures::prelude::future::Either;
30use futures::prelude::stream::{BoxStream, PollNext};
31use futures::stream::select_with_strategy;
32use futures::{FutureExt, StreamExt, TryStreamExt};
33use futures_async_stream::try_stream;
34use itertools::Itertools;
35use plotters::backend::SVGBackend;
36use plotters::chart::ChartBuilder;
37use plotters::drawing::IntoDrawingArea;
38use plotters::element::{Circle, EmptyElement};
39use plotters::series::{LineSeries, PointSeries};
40use plotters::style::{IntoFont, RED, WHITE};
41use risingwave_common::bitmap::Bitmap;
42use risingwave_common::catalog::ColumnId;
43use risingwave_common::types::DataType;
44use risingwave_connector::dispatch_sink;
45use risingwave_connector::parser::{
46 EncodingProperties, ParserConfig, ProtocolProperties, SpecificParserConfig,
47};
48use risingwave_connector::sink::catalog::{
49 SinkEncode, SinkFormat, SinkFormatDesc, SinkId, SinkType,
50};
51use risingwave_connector::sink::log_store::{
52 LogReader, LogStoreReadItem, LogStoreResult, TruncateOffset,
53};
54use risingwave_connector::sink::mock_coordination_client::MockMetaClient;
55use risingwave_connector::sink::{
56 LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_UPSERT, Sink, SinkError, SinkMetaClient, SinkParam,
57 SinkWriterParam, build_sink,
58};
59use risingwave_connector::source::datagen::{
60 DatagenProperties, DatagenSplitEnumerator, DatagenSplitReader,
61};
62use risingwave_connector::source::{
63 Column, SourceContext, SourceEnumeratorContext, SplitEnumerator, SplitReader,
64};
65use risingwave_stream::executor::test_utils::prelude::ColumnDesc;
66use risingwave_stream::executor::{Barrier, Message, MessageStreamItem, StreamExecutorError};
67use sea_orm::DatabaseConnection;
68use serde::{Deserialize, Deserializer};
69use thiserror_ext::AsReport;
70use tokio::sync::oneshot::Sender;
71use tokio::time::sleep;
72
/// Interval between emitted barriers, in milliseconds.
const CHECKPOINT_INTERVAL: u64 = 1000;
/// Sampling period of the background throughput recorder, in milliseconds.
const THROUGHPUT_METRIC_RECORD_INTERVAL: u64 = 500;
/// Total benchmark duration, in seconds.
const BENCH_TIME: u64 = 20;
/// Sentinel sink name that selects the no-sink baseline mode in `main`.
const BENCH_TEST: &str = "bench_test";
77
/// A `LogReader` fed by a `MockDatagenSource` message stream, used to drive
/// sink benchmarks without a real log store.
pub struct MockRangeLogReader {
    // Merged barrier + chunk stream produced by `MockDatagenSource::into_stream`.
    upstreams: BoxStream<'static, MessageStreamItem>,
    // Epoch of the most recently observed barrier; 0 until the first barrier.
    current_epoch: u64,
    // Monotonically increasing id assigned to each chunk handed downstream.
    chunk_id: usize,
    // Row-count collector; `take`n and sent through `result_tx` when stopping.
    throughput_metric: Option<ThroughputMetric>,
    // Signal that the benchmark window has elapsed.
    stop_rx: tokio::sync::mpsc::Receiver<()>,
    // One-shot used to hand the accumulated metrics back to the bench driver.
    result_tx: Option<Sender<ThroughputMetric>>,
}
86
87impl LogReader for MockRangeLogReader {
88 async fn init(&mut self) -> LogStoreResult<()> {
89 self.throughput_metric.as_mut().unwrap().add_metric(0);
90 Ok(())
91 }
92
93 async fn next_item(&mut self) -> LogStoreResult<(u64, LogStoreReadItem)> {
94 tokio::select! {
95 _ = self.stop_rx.recv() => {
96 self.result_tx
97 .take()
98 .unwrap()
99 .send(self.throughput_metric.take().unwrap())
100 .map_err(|_| anyhow!("Can't send throughput_metric"))?;
101 futures::future::pending().await
102 },
103 item = self.upstreams.next() => {
104 match item.unwrap().unwrap() {
105 Message::Barrier(barrier) => {
106 let prev_epoch = self.current_epoch;
107 self.current_epoch = barrier.epoch.curr;
108 Ok((
109 prev_epoch,
110 LogStoreReadItem::Barrier {
111 is_checkpoint: true,
112 new_vnode_bitmap: None,
113 is_stop: false,
114 },
115 ))
116 }
117 Message::Chunk(chunk) => {
118 self.throughput_metric.as_mut().unwrap().add_metric(chunk.capacity());
119 self.chunk_id += 1;
120 Ok((
121 self.current_epoch,
122 LogStoreReadItem::StreamChunk {
123 chunk,
124 chunk_id: self.chunk_id,
125 },
126 ))
127 }
128 _ => Err(anyhow!("Can't assert message type".to_owned())),
129 }
130 }
131 }
132 }
133
134 fn truncate(&mut self, _offset: TruncateOffset) -> LogStoreResult<()> {
135 Ok(())
136 }
137
138 async fn rewind(&mut self) -> LogStoreResult<()> {
139 Err(anyhow!("should not call rewind"))
140 }
141
142 async fn start_from(&mut self, _start_offset: Option<u64>) -> LogStoreResult<()> {
143 Ok(())
144 }
145}
146
147impl MockRangeLogReader {
148 fn new(
149 mock_source: MockDatagenSource,
150 throughput_metric: ThroughputMetric,
151 stop_rx: tokio::sync::mpsc::Receiver<()>,
152 result_tx: Sender<ThroughputMetric>,
153 ) -> MockRangeLogReader {
154 MockRangeLogReader {
155 upstreams: mock_source.into_stream().boxed(),
156 current_epoch: 0,
157 chunk_id: 0,
158 throughput_metric: Some(throughput_metric),
159 stop_rx,
160 result_tx: Some(result_tx),
161 }
162 }
163}
164
/// Samples the cumulative row count at a fixed interval on a background task,
/// so throughput can later be derived as deltas between consecutive samples.
struct ThroughputMetric {
    // Total rows observed so far; bumped by `add_metric`, read by the sampler.
    accumulate_chunk_size: Arc<AtomicU64>,
    // Tells the sampling task to stop and flush its samples.
    stop_tx: oneshot::Sender<()>,
    // Receives the list of cumulative samples once sampling stops.
    vec_rx: oneshot::Receiver<Vec<u64>>,
}
170
impl ThroughputMetric {
    /// Starts the background sampler: every
    /// `THROUGHPUT_METRIC_RECORD_INTERVAL` ms it records the current
    /// cumulative row count, until `stop_tx` fires, at which point the
    /// sample list is sent through the one-shot `vec_tx` channel.
    pub fn new() -> Self {
        let (stop_tx, mut stop_rx) = oneshot::channel::<()>();
        let (vec_tx, vec_rx) = oneshot::channel::<Vec<u64>>();
        let accumulate_chunk_size = Arc::new(AtomicU64::new(0));
        let accumulate_chunk_size_clone = accumulate_chunk_size.clone();
        tokio::spawn(async move {
            let mut chunk_size_list = vec![];
            loop {
                tokio::select! {
                    _ = sleep(tokio::time::Duration::from_millis(
                        THROUGHPUT_METRIC_RECORD_INTERVAL,
                    )) => {
                        chunk_size_list.push(accumulate_chunk_size_clone.load(Ordering::Relaxed));
                    }
                    _ = &mut stop_rx => {
                        vec_tx.send(chunk_size_list).unwrap();
                        break;
                    }
                }
            }
        });

        Self {
            accumulate_chunk_size,
            stop_tx,
            vec_rx,
        }
    }

    /// Adds `chunk_size` rows to the running total (called once per chunk).
    pub fn add_metric(&mut self, chunk_size: usize) {
        self.accumulate_chunk_size
            .fetch_add(chunk_size as u64, Ordering::Relaxed);
    }

    /// Stops sampling, derives per-interval throughput (rows/s) from the
    /// cumulative samples, prints avg/p90/p95/p99 figures, and renders a
    /// line chart of throughput over time to `throughput.svg`.
    pub async fn print_throughput(self) {
        self.stop_tx.send(()).unwrap();
        let throughput_sum_vec = self.vec_rx.await.unwrap();
        // Convert cumulative samples into rates: delta between adjacent
        // samples, scaled from the record interval up to one second.
        #[allow(clippy::disallowed_methods)]
        let throughput_vec = throughput_sum_vec
            .iter()
            .zip(throughput_sum_vec.iter().skip(1))
            .map(|(current, next)| (next - current) * 1000 / THROUGHPUT_METRIC_RECORD_INTERVAL)
            .collect_vec();
        if throughput_vec.is_empty() {
            println!("Throughput Sink: Don't get Throughput, please check");
            return;
        }
        let avg = throughput_vec.iter().sum::<u64>() / throughput_vec.len() as u64;
        // Percentiles via nearest-rank on the sorted rates; index is always
        // < len because k/100 < 1 in integer math after the empty check above.
        let throughput_vec_sorted = throughput_vec.iter().sorted().collect_vec();
        let p90 = throughput_vec_sorted[throughput_vec_sorted.len() * 90 / 100];
        let p95 = throughput_vec_sorted[throughput_vec_sorted.len() * 95 / 100];
        let p99 = throughput_vec_sorted[throughput_vec_sorted.len() * 99 / 100];
        println!("Throughput Sink:");
        println!("avg: {:?} rows/s ", avg);
        println!("p90: {:?} rows/s ", p90);
        println!("p95: {:?} rows/s ", p95);
        println!("p99: {:?} rows/s ", p99);
        // Chart points: x = elapsed seconds since bench start, y = rows/s.
        let draw_vec: Vec<(f32, f32)> = throughput_vec
            .iter()
            .enumerate()
            .map(|(index, &value)| {
                (
                    (index as f32) * (THROUGHPUT_METRIC_RECORD_INTERVAL as f32 / 1000_f32),
                    value as f32,
                )
            })
            .collect();

        let root = SVGBackend::new("throughput.svg", (640, 480)).into_drawing_area();
        root.fill(&WHITE).unwrap();
        let root = root.margin(10, 10, 10, 10);
        // y-axis spans the observed min..max throughput; x-axis the bench window.
        let mut chart = ChartBuilder::on(&root)
            .caption("Throughput Sink", ("sans-serif", 40).into_font())
            .x_label_area_size(20)
            .y_label_area_size(40)
            .build_cartesian_2d(
                0.0..BENCH_TIME as f32,
                **throughput_vec_sorted.first().unwrap() as f32
                    ..**throughput_vec_sorted.last().unwrap() as f32,
            )
            .unwrap();

        chart
            .configure_mesh()
            .x_labels(5)
            .y_labels(5)
            .y_label_formatter(&|x| format!("{:.0}", x))
            .draw()
            .unwrap();

        // Draw the curve plus a filled circle at each sample point.
        chart
            .draw_series(LineSeries::new(draw_vec.clone(), &RED))
            .unwrap();
        chart
            .draw_series(PointSeries::of_element(draw_vec, 5, &RED, &|c, s, st| {
                EmptyElement::at(c) + Circle::new((0, 0), s, st.filled())
            }))
            .unwrap();
        root.present().unwrap();

        println!(
            "Throughput Sink: {:?}",
            throughput_vec
                .iter()
                .map(|a| format!("{} rows/s", a))
                .collect_vec()
        );
    }
}
281
/// Wraps a set of datagen split readers that together emit the benchmark's
/// input rows at a configured rate.
pub struct MockDatagenSource {
    // One reader per datagen split; polled round-robin in `source_to_data_stream`.
    datagen_split_readers: Vec<DatagenSplitReader>,
}
impl MockDatagenSource {
    /// Enumerates the datagen splits and builds one split reader per split,
    /// so rows are generated at `rows_per_second` spread over `split_num`
    /// splits with the given source schema.
    pub async fn new(
        rows_per_second: u64,
        source_schema: Vec<Column>,
        split_num: String,
    ) -> MockDatagenSource {
        let properties = DatagenProperties {
            split_num: Some(split_num),
            rows_per_second,
            fields: HashMap::default(),
        };
        let mut datagen_enumerator = DatagenSplitEnumerator::new(
            properties.clone(),
            SourceEnumeratorContext::dummy().into(),
        )
        .await
        .unwrap();
        // Native encoding/protocol: rows come straight from datagen, no parsing.
        let parser_config = ParserConfig {
            specific: SpecificParserConfig {
                encoding_config: EncodingProperties::Native,
                protocol_config: ProtocolProperties::Native,
            },
            ..Default::default()
        };
        let mut datagen_split_readers = vec![];
        let mut datagen_splits = datagen_enumerator.list_splits().await.unwrap();
        while let Some(splits) = datagen_splits.pop() {
            datagen_split_readers.push(
                DatagenSplitReader::new(
                    properties.clone(),
                    vec![splits],
                    parser_config.clone(),
                    SourceContext::dummy().into(),
                    Some(source_schema.clone()),
                )
                .await
                .unwrap(),
            );
        }
        MockDatagenSource {
            datagen_split_readers,
        }
    }

    /// Endless stream of data chunks, taking one chunk from each split
    /// reader in turn (round-robin).
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub async fn source_to_data_stream(mut self) {
        let mut readers = vec![];
        while let Some(reader) = self.datagen_split_readers.pop() {
            readers.push(reader.into_stream());
        }
        loop {
            for i in &mut readers {
                let item = i.next().await.unwrap().unwrap();
                yield Message::Chunk(item);
            }
        }
    }

    /// Merges the periodic barrier stream with the data stream, preferring
    /// barriers (`PollNext::Left`) so epochs keep advancing under load.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub async fn into_stream(self) {
        let stream = select_with_strategy(
            Self::barrier_to_message_stream().map_ok(Either::Left),
            self.source_to_data_stream().map_ok(Either::Right),
            |_: &mut PollNext| PollNext::Left,
        );
        #[for_await]
        for message in stream {
            match message.unwrap() {
                Either::Left(Message::Barrier(barrier)) => {
                    yield Message::Barrier(barrier);
                }
                Either::Right(Message::Chunk(chunk)) => yield Message::Chunk(chunk),
                _ => {
                    return Err(StreamExecutorError::from(
                        "Can't assert message type".to_owned(),
                    ));
                }
            }
        }
    }

    /// Emits a test barrier every `CHECKPOINT_INTERVAL` ms with epochs
    /// increasing by one each time.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub async fn barrier_to_message_stream() {
        let mut epoch = 0_u64;
        loop {
            let prev_epoch = epoch;
            epoch += 1;
            let barrier = Barrier::with_prev_epoch_for_test(epoch, prev_epoch);
            yield Message::Barrier(barrier);
            sleep(tokio::time::Duration::from_millis(CHECKPOINT_INTERVAL)).await;
        }
    }
}
378
379async fn consume_log_stream<S: Sink>(
380 sink: S,
381 mut log_reader: MockRangeLogReader,
382 mut sink_writer_param: SinkWriterParam,
383) -> Result<(), String>
384where
385 <S as risingwave_connector::sink::Sink>::Coordinator: std::marker::Send,
386 <S as risingwave_connector::sink::Sink>::Coordinator: 'static,
387{
388 if let Ok(coordinator) = sink.new_coordinator(DatabaseConnection::Disconnected).await {
389 sink_writer_param.meta_client = Some(SinkMetaClient::MockMetaClient(MockMetaClient::new(
390 Box::new(coordinator),
391 )));
392 sink_writer_param.vnode_bitmap = Some(Bitmap::ones(1));
393 }
394 let log_sinker = sink.new_log_sinker(sink_writer_param).await.unwrap();
395 match log_sinker.consume_log_and_sink(&mut log_reader).await {
396 Ok(_) => Err("Stream closed".to_owned()),
397 Err(e) => Err(e.to_report_string()),
398 }
399}
400
/// Table definition loaded from the benchmark's schema YAML file.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct TableSchemaFromYml {
    // Logical table name (informational only in this bench).
    table_name: String,
    // Indices into `columns` forming the downstream primary key.
    pk_indices: Vec<usize>,
    // Column name/type pairs; see `ColumnDescFromYml`.
    columns: Vec<ColumnDescFromYml>,
}
408
409impl TableSchemaFromYml {
410 pub fn get_source_schema(&self) -> Vec<Column> {
411 self.columns
412 .iter()
413 .map(|column| Column {
414 name: column.name.clone(),
415 data_type: column.r#type.clone(),
416 is_visible: true,
417 })
418 .collect()
419 }
420
421 pub fn get_sink_schema(&self) -> Vec<ColumnDesc> {
422 self.columns
423 .iter()
424 .map(|column| {
425 ColumnDesc::named(column.name.clone(), ColumnId::new(1), column.r#type.clone())
426 })
427 .collect()
428 }
429}
/// Single column entry in the schema YAML: a name plus a data type parsed
/// from its textual form (see `deserialize_datatype`).
#[derive(Debug, Deserialize)]
struct ColumnDescFromYml {
    name: String,
    #[serde(deserialize_with = "deserialize_datatype")]
    r#type: DataType,
}
436
437fn deserialize_datatype<'de, D>(deserializer: D) -> Result<DataType, D::Error>
438where
439 D: Deserializer<'de>,
440{
441 let s: &str = Deserialize::deserialize(deserializer)?;
442 DataType::from_str(s).map_err(serde::de::Error::custom)
443}
444
445fn read_table_schema_from_yml(path: &str) -> TableSchemaFromYml {
446 let data = std::fs::read_to_string(path).unwrap();
447 let table: TableSchemaFromYml = serde_yaml::from_str(&data).unwrap();
448 table
449}
450
451fn read_sink_option_from_yml(path: &str) -> HashMap<String, BTreeMap<String, String>> {
452 let data = std::fs::read_to_string(path).unwrap();
453 let sink_option: HashMap<String, BTreeMap<String, String>> =
454 serde_yaml::from_str(&data).unwrap();
455 sink_option
456}
457
// CLI options for the sink benchmark. Regular `//` comments are used on the
// fields (not `///`) because clap surfaces doc comments as help text, which
// would change the binary's CLI output.
#[derive(Parser, Debug)]
pub struct Config {
    // Path to the YAML table schema consumed by `read_table_schema_from_yml`.
    #[clap(long, default_value = "./sink_bench/schema.yml")]
    schema_path: String,

    // Path to the YAML sink options keyed by sink name.
    #[clap(short, long, default_value = "./sink_bench/sink_option.yml")]
    option_path: String,

    // Sink name to bench; the default `bench_test` runs the no-sink baseline.
    #[clap(short, long, default_value = BENCH_TEST)]
    sink: String,

    // Row generation rate for the datagen source (required).
    #[clap(short, long)]
    rows_per_second: u64,

    // Number of datagen splits (string form, as `DatagenProperties` expects).
    #[clap(long, default_value = "10")]
    split_num: String,
}
475
476fn mock_from_legacy_type(
477 connector: &str,
478 r#type: &str,
479) -> Result<Option<SinkFormatDesc>, SinkError> {
480 use risingwave_connector::sink::Sink as _;
481 use risingwave_connector::sink::redis::RedisSink;
482 if connector.eq(RedisSink::SINK_NAME) {
483 let format = match r#type {
484 SINK_TYPE_APPEND_ONLY => SinkFormat::AppendOnly,
485 SINK_TYPE_UPSERT => SinkFormat::Upsert,
486 _ => {
487 return Err(SinkError::Config(anyhow!(
488 "sink type unsupported: {}",
489 r#type
490 )));
491 }
492 };
493 Ok(Some(SinkFormatDesc {
494 format,
495 encode: SinkEncode::Json,
496 options: Default::default(),
497 secret_refs: Default::default(),
498 key_encode: None,
499 connection_id: None,
500 }))
501 } else {
502 SinkFormatDesc::from_legacy_type(connector, r#type)
503 }
504}
505
506#[tokio::main]
507async fn main() {
508 let cfg = Config::parse();
509 let table_schema = read_table_schema_from_yml(&cfg.schema_path);
510 let mock_datagen_source = MockDatagenSource::new(
511 cfg.rows_per_second,
512 table_schema.get_source_schema(),
513 cfg.split_num,
514 )
515 .await;
516 let (data_size_tx, data_size_rx) = tokio::sync::oneshot::channel::<ThroughputMetric>();
517 let (stop_tx, stop_rx) = tokio::sync::mpsc::channel::<()>(5);
518 let throughput_metric = ThroughputMetric::new();
519
520 let mut mock_range_log_reader = MockRangeLogReader::new(
521 mock_datagen_source,
522 throughput_metric,
523 stop_rx,
524 data_size_tx,
525 );
526 if cfg.sink.eq(&BENCH_TEST.to_owned()) {
527 println!("Start Sink Bench!, Wait {:?}s", BENCH_TIME);
528 tokio::spawn(async move {
529 mock_range_log_reader.init().await.unwrap();
530 loop {
531 mock_range_log_reader.next_item().await.unwrap();
532 }
533 });
534 } else {
535 let properties = read_sink_option_from_yml(&cfg.option_path)
536 .get(&cfg.sink)
537 .expect("Sink type error")
538 .clone();
539
540 let connector = properties.get("connector").unwrap().clone();
541 let format_desc = mock_from_legacy_type(
542 &connector.clone(),
543 properties.get("type").unwrap_or(&"append-only".to_owned()),
544 )
545 .unwrap();
546 let sink_param = SinkParam {
547 sink_id: SinkId::new(1),
548 sink_name: cfg.sink.clone(),
549 properties,
550 columns: table_schema.get_sink_schema(),
551 downstream_pk: table_schema.pk_indices,
552 sink_type: SinkType::AppendOnly,
553 format_desc,
554 db_name: "not_need_set".to_owned(),
555 sink_from_name: "not_need_set".to_owned(),
556 };
557 let sink = build_sink(sink_param).unwrap();
558 let sink_writer_param = SinkWriterParam::for_test();
559 println!("Start Sink Bench!, Wait {:?}s", BENCH_TIME);
560 tokio::spawn(async move {
561 dispatch_sink!(sink, sink, {
562 consume_log_stream(*sink, mock_range_log_reader, sink_writer_param).boxed()
563 })
564 .await
565 .unwrap();
566 });
567 }
568 sleep(tokio::time::Duration::from_secs(BENCH_TIME)).await;
569 println!("Bench Over!");
570 stop_tx.send(()).await.unwrap();
571 data_size_rx.await.unwrap().print_throughput().await;
572}