// risingwave_connector/source/common.rs

// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16
17use futures::{Stream, StreamExt, TryStreamExt};
18use futures_async_stream::try_stream;
19use risingwave_common::array::StreamChunk;
20
21use crate::error::{ConnectorError, ConnectorResult};
22use crate::parser::ParserConfig;
23use crate::source::{SourceContextRef, SourceMessage};
24
25/// Utility function to convert [`SourceMessage`] stream (got from specific connector's [`SplitReader`](super::SplitReader))
26/// into [`StreamChunk`] stream (by invoking [`ByteStreamSourceParserImpl`](crate::parser::ByteStreamSourceParserImpl)).
27#[try_stream(boxed, ok = StreamChunk, error = ConnectorError)]
28pub(crate) async fn into_chunk_stream(
29    data_stream: impl Stream<Item = ConnectorResult<Vec<SourceMessage>>> + Send + 'static,
30    parser_config: ParserConfig,
31    source_ctx: SourceContextRef,
32) {
33    let actor_id = source_ctx.actor_id.to_string();
34    let fragment_id = source_ctx.fragment_id.to_string();
35    let source_id = source_ctx.source_id.to_string();
36    let source_name = source_ctx.source_name.clone();
37    let metrics = source_ctx.metrics.clone();
38    let mut partition_input_count = HashMap::new();
39    let mut partition_bytes_count = HashMap::new();
40
41    // add metrics to the data stream
42    let data_stream = data_stream
43        .inspect_ok(move |data_batch| {
44            let mut by_split_id = std::collections::HashMap::new();
45
46            for msg in data_batch {
47                let split_id: String = msg.split_id.as_ref().to_owned();
48                by_split_id
49                    .entry(split_id.clone())
50                    .or_insert_with(Vec::new)
51                    .push(msg);
52                partition_input_count
53                    .entry(split_id.clone())
54                    .or_insert_with(|| {
55                        metrics.partition_input_count.with_guarded_label_values(&[
56                            &actor_id,
57                            &source_id,
58                            &split_id.clone(),
59                            &source_name,
60                            &fragment_id,
61                        ])
62                    });
63                partition_bytes_count
64                    .entry(split_id.clone())
65                    .or_insert_with(|| {
66                        metrics.partition_input_bytes.with_guarded_label_values(&[
67                            &actor_id,
68                            &source_id,
69                            &split_id,
70                            &source_name,
71                            &fragment_id,
72                        ])
73                    });
74            }
75            for (split_id, msgs) in by_split_id {
76                partition_input_count
77                    .get_mut(&split_id)
78                    .unwrap()
79                    .inc_by(msgs.len() as u64);
80
81                let sum_bytes = msgs
82                    .iter()
83                    .flat_map(|msg| msg.payload.as_ref().map(|p| p.len() as u64))
84                    .sum();
85
86                partition_bytes_count
87                    .get_mut(&split_id)
88                    .unwrap()
89                    .inc_by(sum_bytes);
90            }
91        })
92        .boxed();
93
94    let parser =
95        crate::parser::ByteStreamSourceParserImpl::create(parser_config, source_ctx).await?;
96    #[for_await]
97    for chunk in parser.parse_stream(data_stream) {
98        yield chunk?;
99    }
100}