risingwave_batch_executors/executor/
generic_exchange.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16use std::time::Duration;
17
18use futures::StreamExt;
19use futures_async_stream::try_stream;
20use risingwave_common::array::DataChunk;
21use risingwave_common::catalog::{Field, Schema};
22use risingwave_common::util::addr::HostAddr;
23use risingwave_common::util::iter_util::ZipEqFast;
24use risingwave_pb::batch_plan::PbExchangeSource;
25use risingwave_pb::batch_plan::plan_node::NodeBody;
26use risingwave_pb::plan_common::Field as NodeField;
27use risingwave_rpc_client::ComputeClientPoolRef;
28use rw_futures_util::select_all;
29
30use crate::error::{BatchError, Result};
31use crate::exchange_source::ExchangeSourceImpl;
32use crate::execution::grpc_exchange::GrpcExchangeSource;
33use crate::execution::local_exchange::LocalExchangeSource;
34use crate::executor::ExecutorBuilder;
35use crate::task::{BatchTaskContext, TaskId};
36
37pub type ExchangeExecutor = GenericExchangeExecutor<DefaultCreateSource>;
38use crate::executor::{BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor};
39use crate::monitor::BatchMetrics;
40
41pub struct GenericExchangeExecutor<CS> {
42    proto_sources: Vec<PbExchangeSource>,
43    /// Mock-able `CreateSource`.
44    source_creators: Vec<CS>,
45    sequential: bool,
46    context: Arc<dyn BatchTaskContext>,
47
48    schema: Schema,
49    #[expect(dead_code)]
50    task_id: TaskId,
51    identity: String,
52
53    /// Batch metrics.
54    /// None: Local mode don't record mertics.
55    metrics: Option<BatchMetrics>,
56}
57
58/// `CreateSource` determines the right type of `ExchangeSource` to create.
59#[async_trait::async_trait]
60pub trait CreateSource: Send {
61    async fn create_source(
62        &self,
63        context: &dyn BatchTaskContext,
64        prost_source: &PbExchangeSource,
65    ) -> Result<ExchangeSourceImpl>;
66}
67
68#[derive(Clone)]
69pub struct DefaultCreateSource {
70    client_pool: ComputeClientPoolRef,
71}
72
73impl DefaultCreateSource {
74    pub fn new(client_pool: ComputeClientPoolRef) -> Self {
75        Self { client_pool }
76    }
77}
78
79#[async_trait::async_trait]
80impl CreateSource for DefaultCreateSource {
81    async fn create_source(
82        &self,
83        context: &dyn BatchTaskContext,
84        prost_source: &PbExchangeSource,
85    ) -> Result<ExchangeSourceImpl> {
86        let peer_addr = prost_source.get_host()?.into();
87        let task_output_id = prost_source.get_task_output_id()?;
88        let task_id = TaskId::from(task_output_id.get_task_id()?);
89
90        if context.is_local_addr(&peer_addr) && prost_source.local_execute_plan.is_none() {
91            trace!("Exchange locally [{:?}]", task_output_id);
92
93            Ok(ExchangeSourceImpl::Local(LocalExchangeSource::create(
94                task_output_id.try_into()?,
95                context,
96                task_id,
97            )?))
98        } else {
99            trace!(
100                "Exchange remotely from {} [{:?}]",
101                &peer_addr, task_output_id,
102            );
103
104            let mask_failed_serving_worker = || {
105                if let Some(worker_node_manager) = context.worker_node_manager() {
106                    if let Some(worker) =
107                        worker_node_manager
108                            .list_worker_nodes()
109                            .iter()
110                            .find(|worker| {
111                                worker
112                                    .host
113                                    .as_ref()
114                                    .is_some_and(|h| HostAddr::from(h) == peer_addr)
115                                    && worker.property.as_ref().is_some_and(|p| p.is_serving)
116                            })
117                    {
118                        let duration = Duration::from_secs(std::cmp::max(
119                            context.get_config().mask_worker_temporary_secs as u64,
120                            1,
121                        ));
122                        worker_node_manager.mask_worker_node(worker.id, duration);
123                    }
124                }
125            };
126
127            Ok(ExchangeSourceImpl::Grpc(
128                GrpcExchangeSource::create(
129                    self.client_pool
130                        .get_by_addr(peer_addr.clone())
131                        .await
132                        .inspect_err(|_| mask_failed_serving_worker())?,
133                    task_output_id.clone(),
134                    prost_source.local_execute_plan.clone(),
135                )
136                .await
137                .inspect_err(|e| {
138                    if matches!(e, BatchError::RpcError(_)) {
139                        mask_failed_serving_worker()
140                    }
141                })?,
142            ))
143        }
144    }
145}
146
147pub struct GenericExchangeExecutorBuilder {}
148
149impl BoxedExecutorBuilder for GenericExchangeExecutorBuilder {
150    async fn new_boxed_executor(
151        source: &ExecutorBuilder<'_>,
152        inputs: Vec<BoxedExecutor>,
153    ) -> Result<BoxedExecutor> {
154        ensure!(
155            inputs.is_empty(),
156            "Exchange executor should not have children!"
157        );
158        let node = try_match_expand!(
159            source.plan_node().get_node_body().unwrap(),
160            NodeBody::Exchange
161        )?;
162
163        let sequential = node.get_sequential();
164
165        ensure!(!node.get_sources().is_empty());
166        let proto_sources: Vec<PbExchangeSource> = node.get_sources().to_vec();
167        let source_creators =
168            vec![DefaultCreateSource::new(source.context().client_pool()); proto_sources.len()];
169
170        let input_schema: Vec<NodeField> = node.get_input_schema().to_vec();
171        let fields = input_schema.iter().map(Field::from).collect::<Vec<Field>>();
172        Ok(Box::new(ExchangeExecutor {
173            proto_sources,
174            source_creators,
175            sequential,
176            context: source.context().clone(),
177            schema: Schema { fields },
178            task_id: source.task_id.clone(),
179            identity: source.plan_node().get_identity().clone(),
180            metrics: source.context().batch_metrics(),
181        }))
182    }
183}
184
185impl<CS: 'static + Send + CreateSource> Executor for GenericExchangeExecutor<CS> {
186    fn schema(&self) -> &Schema {
187        &self.schema
188    }
189
190    fn identity(&self) -> &str {
191        &self.identity
192    }
193
194    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
195        self.do_execute()
196    }
197}
198
199impl<CS: 'static + Send + CreateSource> GenericExchangeExecutor<CS> {
200    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
201    async fn do_execute(self: Box<Self>) {
202        let streams = self
203            .proto_sources
204            .into_iter()
205            .zip_eq_fast(self.source_creators)
206            .map(|(prost_source, source_creator)| {
207                Self::data_chunk_stream(
208                    prost_source,
209                    source_creator,
210                    &*self.context,
211                    self.metrics.clone(),
212                )
213            });
214
215        if self.sequential {
216            for mut stream in streams {
217                while let Some(data_chunk) = stream.next().await {
218                    let data_chunk = data_chunk?;
219                    yield data_chunk
220                }
221            }
222        } else {
223            let mut stream = select_all(streams).boxed();
224            while let Some(data_chunk) = stream.next().await {
225                let data_chunk = data_chunk?;
226                yield data_chunk
227            }
228        }
229    }
230
231    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
232    async fn data_chunk_stream(
233        prost_source: PbExchangeSource,
234        source_creator: CS,
235        context: &dyn BatchTaskContext,
236        metrics: Option<BatchMetrics>,
237    ) {
238        let mut source = source_creator.create_source(context, &prost_source).await?;
239        // Release potential large objects in LocalExecutePlan early.
240        drop(prost_source);
241        // create the collector
242        let counter = metrics
243            .as_ref()
244            .map(|metrics| &metrics.executor_metrics().exchange_recv_row_number);
245
246        loop {
247            if let Some(res) = source.take_data().await? {
248                if res.cardinality() == 0 {
249                    debug!("Exchange source {:?} output empty chunk.", source);
250                }
251
252                if let Some(counter) = counter {
253                    counter.inc_by(res.cardinality().try_into().unwrap());
254                }
255
256                yield res;
257                continue;
258            }
259            break;
260        }
261    }
262}
263
#[cfg(test)]
mod tests {

    use risingwave_common::array::{Array, I32Array};
    use risingwave_common::types::DataType;

    use super::*;
    use crate::executor::test_utils::{FakeCreateSource, FakeExchangeSource};
    use crate::task::ComputeNodeContext;

    /// Runs a non-sequential exchange over two fake sources and checks that every window of
    /// 100 received chunks contains more than one distinct run of chunks, i.e. that the
    /// window was not filled by a single source.
    #[tokio::test]
    async fn test_exchange_multiple_sources() {
        let context = ComputeNodeContext::for_test();
        let mut proto_sources = vec![];
        let mut source_creators = vec![];
        // Each source gets its own fixed value. Drawing both values at random could produce
        // the same value twice, which makes all chunks equal and fails the assertion below
        // even when the sources are interleaved correctly.
        for value in 1..=2_i32 {
            let chunk = DataChunk::new(vec![I32Array::from_iter([value]).into_ref()], 1);
            let chunks = vec![Some(chunk); 100];
            let fake_exchange_source = FakeExchangeSource::new(chunks);
            let fake_create_source = FakeCreateSource::new(fake_exchange_source);
            proto_sources.push(PbExchangeSource::default());
            source_creators.push(fake_create_source);
        }

        let executor = Box::new(GenericExchangeExecutor::<FakeCreateSource> {
            metrics: None,
            proto_sources,
            source_creators,
            sequential: false,
            context,
            schema: Schema {
                fields: vec![Field::unnamed(DataType::Int32)],
            },
            task_id: TaskId::default(),
            identity: "GenericExchangeExecutor2".to_owned(),
        });

        let mut stream = executor.execute();
        let mut chunks: Vec<DataChunk> = vec![];
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.unwrap();
            chunks.push(chunk);
            if chunks.len() == 100 {
                // `dedup` collapses consecutive equal chunks, so a length of 1 would mean
                // the whole window came from one source.
                chunks.dedup();
                assert_ne!(chunks.len(), 1);
                chunks.clear();
            }
        }
    }
}