risingwave_stream/executor/
upstream_sink_union.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use futures::future::try_join_all;
19use pin_project::pin_project;
20use risingwave_expr::expr::NonStrictExpression;
21
22use crate::executor::exchange::input::{Input, new_input};
23use crate::executor::prelude::*;
24use crate::executor::project::apply_project_exprs;
25use crate::executor::{BarrierMutationType, BoxedMessageInput, DynamicReceivers, MergeExecutor};
26use crate::task::{FragmentId, LocalBarrierManager};
27
/// Opaque type of the projected message stream stored in [`SinkHandlerInput`].
///
/// The concrete (hidden) type is fixed by `SinkHandlerInput::apply_project_exprs_stream`,
/// which carries the matching `#[define_opaque(ProcessedMessageStream)]` attribute.
type ProcessedMessageStream = impl Stream<Item = MessageStreamItem>;
29
30/// A wrapper that merges data from a single upstream fragment and applies projection expressions.
31/// Each `SinkHandlerInput` represents one upstream fragment with its own merge executor and projection logic.
32#[pin_project]
33pub struct SinkHandlerInput {
34    /// The ID of the upstream fragment that this input is associated with.
35    upstream_fragment_id: FragmentId,
36
37    /// The stream of messages from the upstream fragment.
38    #[pin]
39    processed_stream: ProcessedMessageStream,
40}
41
42impl SinkHandlerInput {
43    pub fn new(
44        upstream_fragment_id: FragmentId,
45        merge: Box<MergeExecutor>,
46        project_exprs: Vec<NonStrictExpression>,
47    ) -> Self {
48        let processed_stream = Self::apply_project_exprs_stream(merge, project_exprs);
49        Self {
50            upstream_fragment_id,
51            processed_stream,
52        }
53    }
54
55    #[define_opaque(ProcessedMessageStream)]
56    fn apply_project_exprs_stream(
57        merge: Box<MergeExecutor>,
58        project_exprs: Vec<NonStrictExpression>,
59    ) -> ProcessedMessageStream {
60        // Apply the projection expressions to the output of the merge executor.
61        Self::apply_project_exprs_stream_impl(merge, project_exprs)
62    }
63
64    /// Applies a projection to the output of a merge executor.
65    #[try_stream(ok = Message, error = StreamExecutorError)]
66    async fn apply_project_exprs_stream_impl(
67        merge: Box<MergeExecutor>,
68        project_exprs: Vec<NonStrictExpression>,
69    ) {
70        let merge_stream = merge.execute_inner();
71        pin_mut!(merge_stream);
72        while let Some(msg) = merge_stream.next().await {
73            let msg = msg?;
74            if let Message::Chunk(chunk) = msg {
75                // Apply the projection expressions to the chunk.
76                let new_chunk = apply_project_exprs(&project_exprs, chunk).await?;
77                yield Message::Chunk(new_chunk);
78            } else {
79                yield msg;
80            }
81        }
82    }
83}
84
85impl Input for SinkHandlerInput {
86    type InputId = FragmentId;
87
88    fn id(&self) -> Self::InputId {
89        // Return a unique identifier for this input, e.g., based on the upstream fragment ID
90        self.upstream_fragment_id
91    }
92}
93
94impl Stream for SinkHandlerInput {
95    type Item = MessageStreamItem;
96
97    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
98        self.project().processed_stream.poll_next(cx)
99    }
100}
101
/// Per-fragment configuration from which [`UpstreamSinkUnionExecutor`] builds one input.
#[derive(Debug)]
pub struct UpstreamInfo {
    /// Upstream fragment whose initial actors are merged into a single input.
    pub upstream_fragment_id: FragmentId,
    /// Schema handed to the [`MergeExecutor`] created for this fragment.
    pub merge_schema: Schema,
    /// Expressions applied to every chunk received from this fragment.
    pub project_exprs: Vec<NonStrictExpression>,
}
109
/// Type-erased per-fragment input fed to `DynamicReceivers`; its id is the upstream fragment id.
type BoxedSinkInput = BoxedMessageInput<FragmentId, BarrierMutationType>;
111
112/// `UpstreamSinkUnionExecutor` merges data from multiple upstream fragments, where each fragment
113/// has its own merge logic and projection expressions. This executor is specifically designed for
114/// sink operations that need to union data from different upstream sources.
115///
116/// Unlike a simple union that just merges streams, this executor:
117/// 1. Creates a separate `MergeExecutor` for each upstream fragment
118/// 2. Applies fragment-specific projection expressions to each stream
119/// 3. Unions all the processed streams into a single output stream
120///
121/// This is useful for sink operators that need to collect data from multiple upstream fragments
122/// with potentially different schemas or processing requirements.
123pub struct UpstreamSinkUnionExecutor {
124    /// The context of the actor.
125    actor_context: ActorContextRef,
126
127    /// Used to create merge executors.
128    local_barrier_manager: LocalBarrierManager,
129
130    /// Streaming metrics.
131    executor_stats: Arc<StreamingMetrics>,
132
133    /// The size of the chunks to be processed.
134    chunk_size: usize,
135
136    /// The initial inputs to the executor.
137    upstream_infos: Vec<UpstreamInfo>,
138}
139
140impl Debug for UpstreamSinkUnionExecutor {
141    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
142        f.debug_struct("UpstreamSinkUnionExecutor")
143            .field("upstream_infos", &self.upstream_infos)
144            .finish()
145    }
146}
147
148impl Execute for UpstreamSinkUnionExecutor {
149    fn execute(self: Box<Self>) -> BoxedMessageStream {
150        self.execute_inner().boxed()
151    }
152}
153
154impl UpstreamSinkUnionExecutor {
155    pub fn new(
156        ctx: ActorContextRef,
157        local_barrier_manager: LocalBarrierManager,
158        executor_stats: Arc<StreamingMetrics>,
159        chunk_size: usize,
160        upstream_infos: Vec<(FragmentId, Schema, Vec<NonStrictExpression>)>,
161    ) -> Self {
162        Self {
163            actor_context: ctx,
164            local_barrier_manager,
165            executor_stats,
166            chunk_size,
167            upstream_infos: upstream_infos
168                .into_iter()
169                .map(
170                    |(upstream_fragment_id, merge_schema, project_exprs)| UpstreamInfo {
171                        upstream_fragment_id,
172                        merge_schema,
173                        project_exprs,
174                    },
175                )
176                .collect(),
177        }
178    }
179
180    #[cfg(test)]
181    pub fn for_test(
182        actor_id: ActorId,
183        local_barrier_manager: LocalBarrierManager,
184        chunk_size: usize,
185    ) -> Self {
186        let metrics = StreamingMetrics::unused();
187        let actor_ctx = ActorContext::for_test(actor_id);
188        Self {
189            actor_context: actor_ctx,
190            local_barrier_manager,
191            executor_stats: metrics.into(),
192            chunk_size,
193            upstream_infos: vec![],
194        }
195    }
196
197    #[allow(dead_code)]
198    async fn new_sink_input(
199        &self,
200        upstream_info: UpstreamInfo,
201    ) -> StreamExecutorResult<BoxedSinkInput> {
202        let (upstream_fragment_id, merge_schema, project_exprs) = (
203            upstream_info.upstream_fragment_id,
204            upstream_info.merge_schema,
205            upstream_info.project_exprs,
206        );
207
208        let merge_executor = self
209            .new_merge_executor(upstream_fragment_id, merge_schema)
210            .await?;
211
212        Ok(SinkHandlerInput::new(
213            upstream_fragment_id,
214            Box::new(merge_executor),
215            project_exprs,
216        )
217        .boxed_input())
218    }
219
220    async fn new_merge_executor(
221        &self,
222        upstream_fragment_id: FragmentId,
223        schema: Schema,
224    ) -> StreamExecutorResult<MergeExecutor> {
225        let barrier_rx = self
226            .local_barrier_manager
227            .subscribe_barrier(self.actor_context.id);
228
229        let inputs: Vec<_> = try_join_all(
230            self.actor_context
231                .initial_upstream_actors
232                .get(&upstream_fragment_id)
233                .map(|actors| actors.actors.iter())
234                .into_iter()
235                .flatten()
236                .map(|upstream_actor| {
237                    new_input(
238                        &self.local_barrier_manager,
239                        self.executor_stats.clone(),
240                        self.actor_context.id,
241                        self.actor_context.fragment_id,
242                        upstream_actor,
243                        upstream_fragment_id,
244                    )
245                }),
246        )
247        .await?;
248
249        let upstreams =
250            MergeExecutor::new_select_receiver(inputs, &self.executor_stats, &self.actor_context);
251
252        Ok(MergeExecutor::new(
253            self.actor_context.clone(),
254            self.actor_context.fragment_id,
255            upstream_fragment_id,
256            upstreams,
257            self.local_barrier_manager.clone(),
258            self.executor_stats.clone(),
259            barrier_rx,
260            self.chunk_size,
261            schema,
262        ))
263    }
264
265    #[try_stream(ok = Message, error = StreamExecutorError)]
266    async fn execute_inner(mut self: Box<Self>) {
267        let inputs: Vec<_> = {
268            let upstream_infos = std::mem::take(&mut self.upstream_infos);
269            let mut inputs = Vec::with_capacity(upstream_infos.len());
270            for UpstreamInfo {
271                upstream_fragment_id,
272                merge_schema,
273                project_exprs,
274            } in upstream_infos
275            {
276                let merge_executor = self
277                    .new_merge_executor(upstream_fragment_id, merge_schema)
278                    .await?;
279
280                let input = SinkHandlerInput::new(
281                    upstream_fragment_id,
282                    Box::new(merge_executor),
283                    project_exprs,
284                )
285                .boxed_input();
286
287                inputs.push(input);
288            }
289            inputs
290        };
291
292        let execution_stream = self.execute_with_inputs(inputs);
293        pin_mut!(execution_stream);
294        while let Some(msg) = execution_stream.next().await {
295            yield msg?;
296        }
297    }
298
299    #[try_stream(ok = Message, error = StreamExecutorError)]
300    async fn execute_with_inputs(self: Box<Self>, inputs: Vec<BoxedSinkInput>) {
301        let actor_id = self.actor_context.id;
302        let fragment_id = self.actor_context.fragment_id;
303
304        let barrier_align = self
305            .executor_stats
306            .barrier_align_duration
307            .with_guarded_label_values(&[
308                actor_id.to_string().as_str(),
309                fragment_id.to_string().as_str(),
310                "",
311                "UpstreamSinkUnion",
312            ]);
313
314        let upstreams = DynamicReceivers::new(inputs, Some(barrier_align), None);
315        pin_mut!(upstreams);
316
317        while let Some(msg) = upstreams.next().await {
318            yield msg?;
319        }
320    }
321}
322
#[cfg(test)]
mod tests {
    use risingwave_common::array::{Op, StreamChunkTestExt};
    use risingwave_common::catalog::Field;

    use super::*;
    use crate::executor::MessageInner;
    use crate::executor::exchange::permit::{Sender, channel_for_test};
    use crate::executor::test_utils::expr::build_from_pretty;
    use crate::task::barrier_test_utils::LocalBarrierTestEnv;

    /// A single `SinkHandlerInput` merges two upstream channels and applies the projection
    /// `$1:int8` (keep only the second column) to the merged output.
    #[tokio::test]
    async fn test_sink_input() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let actor_id = 2;

        let b1 = Barrier::with_prev_epoch_for_test(2, 1);

        test_env.inject_barrier(&b1, [actor_id]);
        test_env.flush_all_events().await;

        // Two-column input schema; the projection below selects column 1.
        let schema = Schema {
            fields: vec![
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
            ],
        };

        let (tx1, rx1) = channel_for_test();
        let (tx2, rx2) = channel_for_test();

        // The trailing `5` is presumably the merge executor's chunk size. It matches the
        // 3 + 2 rows sent below, which are expected back as one 5-row chunk.
        let merge = MergeExecutor::for_test(
            actor_id,
            vec![rx1, rx2],
            test_env.local_barrier_manager.clone(),
            schema.clone(),
            5,
        );

        let test_expr = build_from_pretty("$1:int8");

        let mut input = SinkHandlerInput::new(
            1919, // from MergeExecutor::for_test()
            Box::new(merge),
            vec![test_expr],
        )
        .boxed_input();

        let chunk1 = StreamChunk::from_pretty(
            " I I
            + 1 4
            + 2 5
            + 3 6",
        );
        let chunk2 = StreamChunk::from_pretty(
            " I I
            + 7 8
            - 3 6",
        );

        tx1.send(MessageInner::Chunk(chunk1).into()).await.unwrap();
        tx2.send(MessageInner::Chunk(chunk2).into()).await.unwrap();

        // Both chunks come back as one chunk holding only the projected second column,
        // with the original ops preserved.
        let msg = input.next().await.unwrap().unwrap();
        assert_eq!(
            *msg.as_chunk().unwrap(),
            StreamChunk::from_pretty(
                " I
                + 4
                + 5
                + 6
                + 8
                - 6"
            )
        );
    }

    /// Creates a `SinkHandlerInput` with an empty schema and no projection, backed by a single
    /// channel, and returns the input together with the channel's sender.
    ///
    /// NOTE(review): `actor_id` is also used as the input id, so every input built by this
    /// helper for the same actor shares one id. Confirm `DynamicReceivers` tolerates
    /// duplicate input ids.
    fn new_input_for_test(
        actor_id: ActorId,
        local_barrier_manager: LocalBarrierManager,
    ) -> (BoxedSinkInput, Sender) {
        let (tx, rx) = channel_for_test();
        let merge = MergeExecutor::for_test(
            actor_id,
            vec![rx],
            local_barrier_manager,
            Schema::new(vec![]),
            10,
        );
        let input = SinkHandlerInput::new(actor_id, Box::new(merge), vec![]).boxed_input();
        (input, tx)
    }

    /// Builds a column-less chunk with `size` insert ops.
    fn build_test_chunk(size: u64) -> StreamChunk {
        let ops = vec![Op::Insert; size as usize];
        StreamChunk::new(ops, vec![])
    }

    /// Three inputs each send two 10-row chunks followed by the same barrier. The union
    /// yields all six chunks before emitting a single barrier for the epoch.
    #[tokio::test]
    async fn test_fixed_upstreams() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let actor_id = 2;

        let b1 = Barrier::with_prev_epoch_for_test(2, 1);

        test_env.inject_barrier(&b1, [actor_id]);
        test_env.flush_all_events().await;

        let mut inputs = Vec::with_capacity(3);
        let mut txs = Vec::with_capacity(3);
        for _ in 0..3 {
            let (input, tx) = new_input_for_test(actor_id, test_env.local_barrier_manager.clone());
            inputs.push(input);
            txs.push(tx);
        }

        let sink_union = UpstreamSinkUnionExecutor::for_test(
            actor_id,
            test_env.local_barrier_manager.clone(),
            10,
        );
        // Inputs are supplied directly, bypassing construction from `UpstreamInfo`s.
        let mut sink_union = Box::new(sink_union).execute_with_inputs(inputs).boxed();

        for tx in txs {
            tx.send(MessageInner::Chunk(build_test_chunk(10)).into())
                .await
                .unwrap();
            tx.send(MessageInner::Chunk(build_test_chunk(10)).into())
                .await
                .unwrap();
            tx.send(MessageInner::Barrier(b1.clone().into_dispatcher()).into())
                .await
                .unwrap();
        }

        // 3 inputs x 2 chunks, each still 10 rows.
        for _ in 0..6 {
            let msg = sink_union.next().await.unwrap().unwrap();
            assert!(msg.is_chunk());
            assert_eq!(msg.as_chunk().unwrap().ops().len(), 10);
        }

        // The barrier is emitted only after all inputs delivered it.
        let msg = sink_union.next().await.unwrap().unwrap();
        assert!(msg.is_barrier());
        let barrier = msg.as_barrier().unwrap();
        assert_eq!(barrier.epoch, b1.epoch);
    }
}