risingwave_batch_executors/executor/generic_exchange.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16use std::time::Duration;
17
18use futures::StreamExt;
19use futures_async_stream::try_stream;
20use risingwave_common::array::DataChunk;
21use risingwave_common::catalog::{Field, Schema};
22use risingwave_common::util::addr::HostAddr;
23use risingwave_common::util::iter_util::ZipEqFast;
24use risingwave_pb::batch_plan::PbExchangeSource;
25use risingwave_pb::batch_plan::plan_node::NodeBody;
26use risingwave_pb::plan_common::Field as NodeField;
27use risingwave_rpc_client::ComputeClientPoolRef;
28use rw_futures_util::select_all;
29
30use crate::error::{BatchError, Result};
31use crate::exchange_source::ExchangeSourceImpl;
32use crate::execution::grpc_exchange::GrpcExchangeSource;
33use crate::execution::local_exchange::LocalExchangeSource;
34use crate::executor::ExecutorBuilder;
35use crate::task::{BatchTaskContext, TaskId};
36
/// The exchange executor used outside of tests: each upstream source is materialized through
/// [`DefaultCreateSource`], which picks a local or a gRPC exchange source.
pub type ExchangeExecutor = GenericExchangeExecutor<DefaultCreateSource>;
38use crate::executor::{BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor};
39use crate::monitor::BatchMetrics;
40
/// Batch executor that reads from one or more upstream exchange sources and merges their output
/// into a single stream of [`DataChunk`]s.
///
/// It is generic over `CS` so that tests can substitute a fake [`CreateSource`].
pub struct GenericExchangeExecutor<CS> {
    /// Protobuf descriptions of the upstream outputs to read from. At execution time they are
    /// zipped one-to-one with `source_creators`.
    proto_sources: Vec<PbExchangeSource>,
    /// Mock-able `CreateSource`.
    source_creators: Vec<CS>,
    /// When `true`, sources are drained one after another in order; otherwise they are polled
    /// together and their chunks are interleaved.
    sequential: bool,
    /// Task context passed to every `CreateSource::create_source` call.
    context: Arc<dyn BatchTaskContext>,

    schema: Schema,
    #[expect(dead_code)]
    task_id: TaskId,
    identity: String,

    /// Batch metrics.
    /// `None` in local mode, which doesn't record metrics.
    metrics: Option<BatchMetrics>,
}
57
/// `CreateSource` determines the right type of `ExchangeSource` to create.
///
/// Implementations turn a protobuf [`PbExchangeSource`] into a concrete [`ExchangeSourceImpl`].
/// The abstraction lets tests inject fake exchange sources into [`GenericExchangeExecutor`].
#[async_trait::async_trait]
pub trait CreateSource: Send {
    /// Builds the exchange source described by `prost_source` within `context`.
    ///
    /// Returns an error if the source cannot be created.
    async fn create_source(
        &self,
        context: &dyn BatchTaskContext,
        prost_source: &PbExchangeSource,
    ) -> Result<ExchangeSourceImpl>;
}
67
68#[derive(Clone)]
69pub struct DefaultCreateSource {
70    client_pool: ComputeClientPoolRef,
71}
72
73impl DefaultCreateSource {
74    pub fn new(client_pool: ComputeClientPoolRef) -> Self {
75        Self { client_pool }
76    }
77}
78
79#[async_trait::async_trait]
80impl CreateSource for DefaultCreateSource {
81    async fn create_source(
82        &self,
83        context: &dyn BatchTaskContext,
84        prost_source: &PbExchangeSource,
85    ) -> Result<ExchangeSourceImpl> {
86        let peer_addr = prost_source.get_host()?.into();
87        let task_output_id = prost_source.get_task_output_id()?;
88        let task_id = TaskId::from(task_output_id.get_task_id()?);
89
90        if context.is_local_addr(&peer_addr) && prost_source.local_execute_plan.is_none() {
91            trace!("Exchange locally [{:?}]", task_output_id);
92
93            Ok(ExchangeSourceImpl::Local(LocalExchangeSource::create(
94                task_output_id.try_into()?,
95                context,
96                task_id,
97            )?))
98        } else {
99            trace!(
100                "Exchange remotely from {} [{:?}]",
101                &peer_addr, task_output_id,
102            );
103
104            let mask_failed_serving_worker = || {
105                if let Some(worker_node_manager) = context.worker_node_manager()
106                    && let Some(worker) =
107                        worker_node_manager
108                            .list_compute_nodes()
109                            .iter()
110                            .find(|worker| {
111                                worker
112                                    .host
113                                    .as_ref()
114                                    .is_some_and(|h| HostAddr::from(h) == peer_addr)
115                                    && worker.property.as_ref().is_some_and(|p| p.is_serving)
116                            })
117                {
118                    let duration = Duration::from_secs(std::cmp::max(
119                        context.get_config().mask_worker_temporary_secs as u64,
120                        1,
121                    ));
122                    worker_node_manager.mask_worker_node(worker.id, duration);
123                }
124            };
125
126            Ok(ExchangeSourceImpl::Grpc(
127                GrpcExchangeSource::create(
128                    self.client_pool
129                        .get_by_addr(peer_addr.clone())
130                        .await
131                        .inspect_err(|_| mask_failed_serving_worker())?,
132                    task_output_id.clone(),
133                    prost_source.local_execute_plan.clone(),
134                )
135                .await
136                .inspect_err(|e| {
137                    if matches!(e, BatchError::RpcError(_)) {
138                        mask_failed_serving_worker()
139                    }
140                })?,
141            ))
142        }
143    }
144}
145
/// Builder that constructs an [`ExchangeExecutor`] from an `Exchange` plan node.
pub struct GenericExchangeExecutorBuilder {}
147
148impl BoxedExecutorBuilder for GenericExchangeExecutorBuilder {
149    async fn new_boxed_executor(
150        source: &ExecutorBuilder<'_>,
151        inputs: Vec<BoxedExecutor>,
152    ) -> Result<BoxedExecutor> {
153        ensure!(
154            inputs.is_empty(),
155            "Exchange executor should not have children!"
156        );
157        let node = try_match_expand!(
158            source.plan_node().get_node_body().unwrap(),
159            NodeBody::Exchange
160        )?;
161
162        let sequential = node.get_sequential();
163
164        ensure!(!node.get_sources().is_empty());
165        let proto_sources: Vec<PbExchangeSource> = node.get_sources().to_vec();
166        let source_creators =
167            vec![DefaultCreateSource::new(source.context().client_pool()); proto_sources.len()];
168
169        let input_schema: Vec<NodeField> = node.get_input_schema().to_vec();
170        let fields = input_schema.iter().map(Field::from).collect::<Vec<Field>>();
171        Ok(Box::new(ExchangeExecutor {
172            proto_sources,
173            source_creators,
174            sequential,
175            context: source.context().clone(),
176            schema: Schema { fields },
177            task_id: source.task_id.clone(),
178            identity: source.plan_node().get_identity().clone(),
179            metrics: source.context().batch_metrics(),
180        }))
181    }
182}
183
184impl<CS: 'static + Send + CreateSource> Executor for GenericExchangeExecutor<CS> {
185    fn schema(&self) -> &Schema {
186        &self.schema
187    }
188
189    fn identity(&self) -> &str {
190        &self.identity
191    }
192
193    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
194        self.do_execute()
195    }
196}
197
198impl<CS: 'static + Send + CreateSource> GenericExchangeExecutor<CS> {
199    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
200    async fn do_execute(self: Box<Self>) {
201        let streams = self
202            .proto_sources
203            .into_iter()
204            .zip_eq_fast(self.source_creators)
205            .map(|(prost_source, source_creator)| {
206                Self::data_chunk_stream(
207                    prost_source,
208                    source_creator,
209                    &*self.context,
210                    self.metrics.clone(),
211                )
212            });
213
214        if self.sequential {
215            for mut stream in streams {
216                while let Some(data_chunk) = stream.next().await {
217                    let data_chunk = data_chunk?;
218                    yield data_chunk
219                }
220            }
221        } else {
222            let mut stream = select_all(streams).boxed();
223            while let Some(data_chunk) = stream.next().await {
224                let data_chunk = data_chunk?;
225                yield data_chunk
226            }
227        }
228    }
229
230    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
231    async fn data_chunk_stream(
232        prost_source: PbExchangeSource,
233        source_creator: CS,
234        context: &dyn BatchTaskContext,
235        metrics: Option<BatchMetrics>,
236    ) {
237        let mut source = source_creator.create_source(context, &prost_source).await?;
238        // Release potential large objects in LocalExecutePlan early.
239        drop(prost_source);
240        // create the collector
241        let counter = metrics
242            .as_ref()
243            .map(|metrics| &metrics.executor_metrics().exchange_recv_row_number);
244
245        loop {
246            if let Some(res) = source.take_data().await? {
247                if res.cardinality() == 0 {
248                    debug!("Exchange source {:?} output empty chunk.", source);
249                }
250
251                if let Some(counter) = counter {
252                    counter.inc_by(res.cardinality().try_into().unwrap());
253                }
254
255                yield res;
256                continue;
257            }
258            break;
259        }
260    }
261}
262
#[cfg(test)]
mod tests {

    use risingwave_common::array::{Array, I32Array};
    use risingwave_common::types::DataType;

    use super::*;
    use crate::executor::test_utils::{FakeCreateSource, FakeExchangeSource};
    use crate::task::ComputeNodeContext;

    /// Two fake sources, each repeating its own single-row chunk 100 times, are read in
    /// non-sequential mode. Every consecutive window of 100 output chunks must contain chunks
    /// from both sources.
    #[tokio::test]
    async fn test_exchange_multiple_sources() {
        let context = ComputeNodeContext::for_test();
        let mut proto_sources = vec![];
        let mut source_creators = vec![];
        // Use distinct, deterministic values per source. Independently drawn random values could
        // collide, making both sources emit identical chunks. The window assertion below would
        // then fail spuriously.
        for value in 1..=2_i32 {
            let chunk = DataChunk::new(vec![I32Array::from_iter([value]).into_ref()], 1);
            let repeated = vec![Some(chunk); 100];
            let fake_exchange_source = FakeExchangeSource::new(repeated);
            let fake_create_source = FakeCreateSource::new(fake_exchange_source);
            proto_sources.push(PbExchangeSource::default());
            source_creators.push(fake_create_source);
        }

        let executor = Box::new(GenericExchangeExecutor::<FakeCreateSource> {
            metrics: None,
            proto_sources,
            source_creators,
            sequential: false,
            context,
            schema: Schema {
                fields: vec![Field::unnamed(DataType::Int32)],
            },
            task_id: TaskId::default(),
            identity: "GenericExchangeExecutor2".to_owned(),
        });

        let mut stream = executor.execute();
        let mut chunks: Vec<DataChunk> = vec![];
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.unwrap();
            chunks.push(chunk);
            if chunks.len() == 100 {
                chunks.dedup();
                assert_ne!(chunks.len(), 1);
                chunks.clear();
            }
        }
    }
}