risingwave_batch_executors/executor/
generic_exchange.rs1use std::sync::Arc;
16use std::time::Duration;
17
18use futures::StreamExt;
19use futures_async_stream::try_stream;
20use risingwave_common::array::DataChunk;
21use risingwave_common::catalog::{Field, Schema};
22use risingwave_common::util::addr::HostAddr;
23use risingwave_common::util::iter_util::ZipEqFast;
24use risingwave_pb::batch_plan::PbExchangeSource;
25use risingwave_pb::batch_plan::plan_node::NodeBody;
26use risingwave_pb::plan_common::Field as NodeField;
27use risingwave_rpc_client::ComputeClientPoolRef;
28use rw_futures_util::select_all;
29
30use crate::error::{BatchError, Result};
31use crate::exchange_source::ExchangeSourceImpl;
32use crate::execution::grpc_exchange::GrpcExchangeSource;
33use crate::execution::local_exchange::LocalExchangeSource;
34use crate::executor::ExecutorBuilder;
35use crate::task::{BatchTaskContext, TaskId};
36
37pub type ExchangeExecutor = GenericExchangeExecutor<DefaultCreateSource>;
38use crate::executor::{BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor};
39use crate::monitor::BatchMetrics;
40
41pub struct GenericExchangeExecutor<CS> {
42 proto_sources: Vec<PbExchangeSource>,
43 source_creators: Vec<CS>,
45 sequential: bool,
46 context: Arc<dyn BatchTaskContext>,
47
48 schema: Schema,
49 #[expect(dead_code)]
50 task_id: TaskId,
51 identity: String,
52
53 metrics: Option<BatchMetrics>,
56}
57
/// Factory that turns a protobuf exchange source description into a readable
/// [`ExchangeSourceImpl`].
///
/// Abstracted behind a trait so that [`GenericExchangeExecutor`] can be driven either by
/// [`DefaultCreateSource`] or by a substitute implementation (the tests use a fake one).
#[async_trait::async_trait]
pub trait CreateSource: Send {
    /// Creates the exchange source described by `prost_source`, using `context` for
    /// task-level information.
    async fn create_source(
        &self,
        context: &dyn BatchTaskContext,
        prost_source: &PbExchangeSource,
    ) -> Result<ExchangeSourceImpl>;
}
67
68#[derive(Clone)]
69pub struct DefaultCreateSource {
70 client_pool: ComputeClientPoolRef,
71}
72
73impl DefaultCreateSource {
74 pub fn new(client_pool: ComputeClientPoolRef) -> Self {
75 Self { client_pool }
76 }
77}
78
79#[async_trait::async_trait]
80impl CreateSource for DefaultCreateSource {
81 async fn create_source(
82 &self,
83 context: &dyn BatchTaskContext,
84 prost_source: &PbExchangeSource,
85 ) -> Result<ExchangeSourceImpl> {
86 let peer_addr = prost_source.get_host()?.into();
87 let task_output_id = prost_source.get_task_output_id()?;
88 let task_id = TaskId::from(task_output_id.get_task_id()?);
89
90 if context.is_local_addr(&peer_addr) && prost_source.local_execute_plan.is_none() {
91 trace!("Exchange locally [{:?}]", task_output_id);
92
93 Ok(ExchangeSourceImpl::Local(LocalExchangeSource::create(
94 task_output_id.try_into()?,
95 context,
96 task_id,
97 )?))
98 } else {
99 trace!(
100 "Exchange remotely from {} [{:?}]",
101 &peer_addr, task_output_id,
102 );
103
104 let mask_failed_serving_worker = || {
105 if let Some(worker_node_manager) = context.worker_node_manager() {
106 if let Some(worker) =
107 worker_node_manager
108 .list_worker_nodes()
109 .iter()
110 .find(|worker| {
111 worker
112 .host
113 .as_ref()
114 .is_some_and(|h| HostAddr::from(h) == peer_addr)
115 && worker.property.as_ref().is_some_and(|p| p.is_serving)
116 })
117 {
118 let duration = Duration::from_secs(std::cmp::max(
119 context.get_config().mask_worker_temporary_secs as u64,
120 1,
121 ));
122 worker_node_manager.mask_worker_node(worker.id, duration);
123 }
124 }
125 };
126
127 Ok(ExchangeSourceImpl::Grpc(
128 GrpcExchangeSource::create(
129 self.client_pool
130 .get_by_addr(peer_addr.clone())
131 .await
132 .inspect_err(|_| mask_failed_serving_worker())?,
133 task_output_id.clone(),
134 prost_source.local_execute_plan.clone(),
135 )
136 .await
137 .inspect_err(|e| {
138 if matches!(e, BatchError::RpcError(_)) {
139 mask_failed_serving_worker()
140 }
141 })?,
142 ))
143 }
144 }
145}
146
147pub struct GenericExchangeExecutorBuilder {}
148
149impl BoxedExecutorBuilder for GenericExchangeExecutorBuilder {
150 async fn new_boxed_executor(
151 source: &ExecutorBuilder<'_>,
152 inputs: Vec<BoxedExecutor>,
153 ) -> Result<BoxedExecutor> {
154 ensure!(
155 inputs.is_empty(),
156 "Exchange executor should not have children!"
157 );
158 let node = try_match_expand!(
159 source.plan_node().get_node_body().unwrap(),
160 NodeBody::Exchange
161 )?;
162
163 let sequential = node.get_sequential();
164
165 ensure!(!node.get_sources().is_empty());
166 let proto_sources: Vec<PbExchangeSource> = node.get_sources().to_vec();
167 let source_creators =
168 vec![DefaultCreateSource::new(source.context().client_pool()); proto_sources.len()];
169
170 let input_schema: Vec<NodeField> = node.get_input_schema().to_vec();
171 let fields = input_schema.iter().map(Field::from).collect::<Vec<Field>>();
172 Ok(Box::new(ExchangeExecutor {
173 proto_sources,
174 source_creators,
175 sequential,
176 context: source.context().clone(),
177 schema: Schema { fields },
178 task_id: source.task_id.clone(),
179 identity: source.plan_node().get_identity().clone(),
180 metrics: source.context().batch_metrics(),
181 }))
182 }
183}
184
185impl<CS: 'static + Send + CreateSource> Executor for GenericExchangeExecutor<CS> {
186 fn schema(&self) -> &Schema {
187 &self.schema
188 }
189
190 fn identity(&self) -> &str {
191 &self.identity
192 }
193
194 fn execute(self: Box<Self>) -> BoxedDataChunkStream {
195 self.do_execute()
196 }
197}
198
199impl<CS: 'static + Send + CreateSource> GenericExchangeExecutor<CS> {
200 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
201 async fn do_execute(self: Box<Self>) {
202 let streams = self
203 .proto_sources
204 .into_iter()
205 .zip_eq_fast(self.source_creators)
206 .map(|(prost_source, source_creator)| {
207 Self::data_chunk_stream(
208 prost_source,
209 source_creator,
210 &*self.context,
211 self.metrics.clone(),
212 )
213 });
214
215 if self.sequential {
216 for mut stream in streams {
217 while let Some(data_chunk) = stream.next().await {
218 let data_chunk = data_chunk?;
219 yield data_chunk
220 }
221 }
222 } else {
223 let mut stream = select_all(streams).boxed();
224 while let Some(data_chunk) = stream.next().await {
225 let data_chunk = data_chunk?;
226 yield data_chunk
227 }
228 }
229 }
230
231 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
232 async fn data_chunk_stream(
233 prost_source: PbExchangeSource,
234 source_creator: CS,
235 context: &dyn BatchTaskContext,
236 metrics: Option<BatchMetrics>,
237 ) {
238 let mut source = source_creator.create_source(context, &prost_source).await?;
239 drop(prost_source);
241 let counter = metrics
243 .as_ref()
244 .map(|metrics| &metrics.executor_metrics().exchange_recv_row_number);
245
246 loop {
247 if let Some(res) = source.take_data().await? {
248 if res.cardinality() == 0 {
249 debug!("Exchange source {:?} output empty chunk.", source);
250 }
251
252 if let Some(counter) = counter {
253 counter.inc_by(res.cardinality().try_into().unwrap());
254 }
255
256 yield res;
257 continue;
258 }
259 break;
260 }
261 }
262}
263
#[cfg(test)]
mod tests {
    use risingwave_common::array::{Array, I32Array};
    use risingwave_common::types::DataType;

    use super::*;
    use crate::executor::test_utils::{FakeCreateSource, FakeExchangeSource};
    use crate::task::ComputeNodeContext;

    /// Merges two fake sources with `sequential: false`.
    ///
    /// Each source emits 100 identical single-row chunks, and the value differs between the
    /// two sources. The test checks that every window of 100 received chunks contains chunks
    /// from both sources, i.e. that the merge interleaves its inputs.
    #[tokio::test]
    async fn test_exchange_multiple_sources() {
        let context = ComputeNodeContext::for_test();
        let mut proto_sources = vec![];
        let mut source_creators = vec![];
        // Deterministic, distinct per-source values. Random values could collide, making
        // every chunk identical and failing the interleaving assertion spuriously.
        for value in 1..=2_i32 {
            let chunk = DataChunk::new(vec![I32Array::from_iter([value]).into_ref()], 1);
            let chunks = vec![Some(chunk); 100];
            let fake_exchange_source = FakeExchangeSource::new(chunks);
            let fake_create_source = FakeCreateSource::new(fake_exchange_source);
            proto_sources.push(PbExchangeSource::default());
            source_creators.push(fake_create_source);
        }

        let executor = Box::new(GenericExchangeExecutor::<FakeCreateSource> {
            metrics: None,
            proto_sources,
            source_creators,
            sequential: false,
            context,
            schema: Schema {
                fields: vec![Field::unnamed(DataType::Int32)],
            },
            task_id: TaskId::default(),
            identity: "GenericExchangeExecutor2".to_owned(),
        });

        let mut stream = executor.execute();
        let mut chunks: Vec<DataChunk> = vec![];
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.unwrap();
            chunks.push(chunk);
            if chunks.len() == 100 {
                // After removing consecutive duplicates, more than one chunk must remain,
                // meaning both sources contributed to this window.
                chunks.dedup();
                assert_ne!(chunks.len(), 1);
                chunks.clear();
            }
        }
    }
}