// risingwave_connector/source/common.rs

// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use futures::{Stream, StreamExt, TryStreamExt};
use futures_async_stream::try_stream;
use risingwave_common::array::StreamChunk;

use crate::error::{ConnectorError, ConnectorResult};
use crate::parser::ParserConfig;
use crate::source::{
    BoxSourceMessageEventStream, SourceContextRef, SourceMessage, SourceMessageEvent,
    SourceReaderEvent,
};

28/// Utility function to convert [`SourceMessage`] stream (got from specific connector's [`SplitReader`](super::SplitReader))
29/// into [`StreamChunk`] stream (by invoking [`ByteStreamSourceParserImpl`](crate::parser::ByteStreamSourceParserImpl)).
30#[try_stream(boxed, ok = StreamChunk, error = ConnectorError)]
31pub(crate) async fn into_chunk_stream(
32    data_stream: impl Stream<Item = ConnectorResult<Vec<SourceMessage>>> + Send + 'static,
33    parser_config: ParserConfig,
34    source_ctx: SourceContextRef,
35) {
36    let event_stream = into_chunk_event_stream(
37        data_stream.map_ok(SourceMessageEvent::Data),
38        parser_config,
39        source_ctx,
40    )
41    .try_filter_map(|event| async move {
42        Ok(match event {
43            SourceReaderEvent::DataChunk(chunk) => Some(chunk),
44            SourceReaderEvent::SplitProgress(_) => None,
45        })
46    });
47    #[for_await]
48    for chunk in event_stream {
49        yield chunk?;
50    }
51}
52
53#[try_stream(boxed, ok = SourceReaderEvent, error = ConnectorError)]
54pub(crate) async fn into_chunk_event_stream(
55    data_stream: impl Stream<Item = ConnectorResult<SourceMessageEvent>> + Send + 'static,
56    parser_config: ParserConfig,
57    source_ctx: SourceContextRef,
58) {
59    let actor_id = source_ctx.actor_id.to_string();
60    let fragment_id = source_ctx.fragment_id.to_string();
61    let source_id = source_ctx.source_id.to_string();
62    let source_name = source_ctx.source_name.clone();
63    let metrics = source_ctx.metrics.clone();
64    let mut partition_input_count = HashMap::new();
65    let mut partition_bytes_count = HashMap::new();
66
67    // add metrics to the data stream
68    let data_stream = data_stream
69        .inspect_ok(move |event| {
70            let SourceMessageEvent::Data(data_batch) = event else {
71                return;
72            };
73
74            let mut by_split_id = std::collections::HashMap::new();
75
76            for msg in data_batch {
77                let split_id: String = msg.split_id.as_ref().to_owned();
78                by_split_id
79                    .entry(split_id.clone())
80                    .or_insert_with(Vec::new)
81                    .push(msg);
82                partition_input_count
83                    .entry(split_id.clone())
84                    .or_insert_with(|| {
85                        metrics.partition_input_count.with_guarded_label_values(&[
86                            &actor_id,
87                            &source_id,
88                            &split_id.clone(),
89                            &source_name,
90                            &fragment_id,
91                        ])
92                    });
93                partition_bytes_count
94                    .entry(split_id.clone())
95                    .or_insert_with(|| {
96                        metrics.partition_input_bytes.with_guarded_label_values(&[
97                            &actor_id,
98                            &source_id,
99                            &split_id,
100                            &source_name,
101                            &fragment_id,
102                        ])
103                    });
104            }
105            for (split_id, msgs) in by_split_id {
106                partition_input_count
107                    .get_mut(&split_id)
108                    .unwrap()
109                    .inc_by(msgs.len() as u64);
110
111                let sum_bytes = msgs
112                    .iter()
113                    .flat_map(|msg| msg.payload.as_ref().map(|p| p.len() as u64))
114                    .sum();
115
116                partition_bytes_count
117                    .get_mut(&split_id)
118                    .unwrap()
119                    .inc_by(sum_bytes);
120            }
121        })
122        .boxed();
123    let data_stream: BoxSourceMessageEventStream = data_stream;
124
125    let parser =
126        crate::parser::ByteStreamSourceParserImpl::create(parser_config, source_ctx).await?;
127    #[for_await]
128    for event in parser.parse_stream_with_events(data_stream) {
129        yield event?;
130    }
131}