risingwave_batch_executors/executor/
generic_exchange.rs1use std::sync::Arc;
16use std::time::Duration;
17
18use futures::StreamExt;
19use futures_async_stream::try_stream;
20use risingwave_common::array::DataChunk;
21use risingwave_common::catalog::{Field, Schema};
22use risingwave_common::util::addr::HostAddr;
23use risingwave_common::util::iter_util::ZipEqFast;
24use risingwave_pb::batch_plan::PbExchangeSource;
25use risingwave_pb::batch_plan::plan_node::NodeBody;
26use risingwave_pb::plan_common::Field as NodeField;
27use risingwave_rpc_client::ComputeClientPoolRef;
28use rw_futures_util::select_all;
29
30use crate::error::{BatchError, Result};
31use crate::exchange_source::ExchangeSourceImpl;
32use crate::execution::grpc_exchange::GrpcExchangeSource;
33use crate::execution::local_exchange::LocalExchangeSource;
34use crate::executor::ExecutorBuilder;
35use crate::task::{BatchTaskContext, TaskId};
36
37pub type ExchangeExecutor = GenericExchangeExecutor<DefaultCreateSource>;
38use crate::executor::{BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor};
39use crate::monitor::BatchMetrics;
40
41pub struct GenericExchangeExecutor<CS> {
42 proto_sources: Vec<PbExchangeSource>,
43 source_creators: Vec<CS>,
45 sequential: bool,
46 context: Arc<dyn BatchTaskContext>,
47
48 schema: Schema,
49 #[expect(dead_code)]
50 task_id: TaskId,
51 identity: String,
52
53 metrics: Option<BatchMetrics>,
56}
57
58#[async_trait::async_trait]
60pub trait CreateSource: Send {
61 async fn create_source(
62 &self,
63 context: &dyn BatchTaskContext,
64 prost_source: &PbExchangeSource,
65 ) -> Result<ExchangeSourceImpl>;
66}
67
68#[derive(Clone)]
69pub struct DefaultCreateSource {
70 client_pool: ComputeClientPoolRef,
71}
72
73impl DefaultCreateSource {
74 pub fn new(client_pool: ComputeClientPoolRef) -> Self {
75 Self { client_pool }
76 }
77}
78
79#[async_trait::async_trait]
80impl CreateSource for DefaultCreateSource {
81 async fn create_source(
82 &self,
83 context: &dyn BatchTaskContext,
84 prost_source: &PbExchangeSource,
85 ) -> Result<ExchangeSourceImpl> {
86 let peer_addr = prost_source.get_host()?.into();
87 let task_output_id = prost_source.get_task_output_id()?;
88 let task_id = TaskId::from(task_output_id.get_task_id()?);
89
90 if context.is_local_addr(&peer_addr) && prost_source.local_execute_plan.is_none() {
91 trace!("Exchange locally [{:?}]", task_output_id);
92
93 Ok(ExchangeSourceImpl::Local(LocalExchangeSource::create(
94 task_output_id.try_into()?,
95 context,
96 task_id,
97 )?))
98 } else {
99 trace!(
100 "Exchange remotely from {} [{:?}]",
101 &peer_addr, task_output_id,
102 );
103
104 let mask_failed_serving_worker = || {
105 if let Some(worker_node_manager) = context.worker_node_manager()
106 && let Some(worker) =
107 worker_node_manager
108 .list_compute_nodes()
109 .iter()
110 .find(|worker| {
111 worker
112 .host
113 .as_ref()
114 .is_some_and(|h| HostAddr::from(h) == peer_addr)
115 && worker.property.as_ref().is_some_and(|p| p.is_serving)
116 })
117 {
118 let duration = Duration::from_secs(std::cmp::max(
119 context.get_config().mask_worker_temporary_secs as u64,
120 1,
121 ));
122 worker_node_manager.mask_worker_node(worker.id, duration);
123 }
124 };
125
126 Ok(ExchangeSourceImpl::Grpc(
127 GrpcExchangeSource::create(
128 self.client_pool
129 .get_by_addr(peer_addr.clone())
130 .await
131 .inspect_err(|_| mask_failed_serving_worker())?,
132 task_output_id.clone(),
133 prost_source.local_execute_plan.clone(),
134 )
135 .await
136 .inspect_err(|e| {
137 if matches!(e, BatchError::RpcError(_)) {
138 mask_failed_serving_worker()
139 }
140 })?,
141 ))
142 }
143 }
144}
145
146pub struct GenericExchangeExecutorBuilder {}
147
148impl BoxedExecutorBuilder for GenericExchangeExecutorBuilder {
149 async fn new_boxed_executor(
150 source: &ExecutorBuilder<'_>,
151 inputs: Vec<BoxedExecutor>,
152 ) -> Result<BoxedExecutor> {
153 ensure!(
154 inputs.is_empty(),
155 "Exchange executor should not have children!"
156 );
157 let node = try_match_expand!(
158 source.plan_node().get_node_body().unwrap(),
159 NodeBody::Exchange
160 )?;
161
162 let sequential = node.get_sequential();
163
164 ensure!(!node.get_sources().is_empty());
165 let proto_sources: Vec<PbExchangeSource> = node.get_sources().to_vec();
166 let source_creators =
167 vec![DefaultCreateSource::new(source.context().client_pool()); proto_sources.len()];
168
169 let input_schema: Vec<NodeField> = node.get_input_schema().to_vec();
170 let fields = input_schema.iter().map(Field::from).collect::<Vec<Field>>();
171 Ok(Box::new(ExchangeExecutor {
172 proto_sources,
173 source_creators,
174 sequential,
175 context: source.context().clone(),
176 schema: Schema { fields },
177 task_id: source.task_id.clone(),
178 identity: source.plan_node().get_identity().clone(),
179 metrics: source.context().batch_metrics(),
180 }))
181 }
182}
183
184impl<CS: 'static + Send + CreateSource> Executor for GenericExchangeExecutor<CS> {
185 fn schema(&self) -> &Schema {
186 &self.schema
187 }
188
189 fn identity(&self) -> &str {
190 &self.identity
191 }
192
193 fn execute(self: Box<Self>) -> BoxedDataChunkStream {
194 self.do_execute()
195 }
196}
197
198impl<CS: 'static + Send + CreateSource> GenericExchangeExecutor<CS> {
199 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
200 async fn do_execute(self: Box<Self>) {
201 let streams = self
202 .proto_sources
203 .into_iter()
204 .zip_eq_fast(self.source_creators)
205 .map(|(prost_source, source_creator)| {
206 Self::data_chunk_stream(
207 prost_source,
208 source_creator,
209 &*self.context,
210 self.metrics.clone(),
211 )
212 });
213
214 if self.sequential {
215 for mut stream in streams {
216 while let Some(data_chunk) = stream.next().await {
217 let data_chunk = data_chunk?;
218 yield data_chunk
219 }
220 }
221 } else {
222 let mut stream = select_all(streams).boxed();
223 while let Some(data_chunk) = stream.next().await {
224 let data_chunk = data_chunk?;
225 yield data_chunk
226 }
227 }
228 }
229
230 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
231 async fn data_chunk_stream(
232 prost_source: PbExchangeSource,
233 source_creator: CS,
234 context: &dyn BatchTaskContext,
235 metrics: Option<BatchMetrics>,
236 ) {
237 let mut source = source_creator.create_source(context, &prost_source).await?;
238 drop(prost_source);
240 let counter = metrics
242 .as_ref()
243 .map(|metrics| &metrics.executor_metrics().exchange_recv_row_number);
244
245 loop {
246 if let Some(res) = source.take_data().await? {
247 if res.cardinality() == 0 {
248 debug!("Exchange source {:?} output empty chunk.", source);
249 }
250
251 if let Some(counter) = counter {
252 counter.inc_by(res.cardinality().try_into().unwrap());
253 }
254
255 yield res;
256 continue;
257 }
258 break;
259 }
260 }
261}
262
#[cfg(test)]
mod tests {
    use risingwave_common::array::{Array, I32Array};
    use risingwave_common::types::DataType;

    use super::*;
    use crate::executor::test_utils::{FakeCreateSource, FakeExchangeSource};
    use crate::task::ComputeNodeContext;

    /// Checks that concurrently polled sources are interleaved rather than drained one
    /// at a time, and that no chunk is lost.
    #[tokio::test]
    async fn test_exchange_multiple_sources() {
        const NUM_SOURCES: i32 = 2;
        const CHUNKS_PER_SOURCE: usize = 100;

        let context = ComputeNodeContext::for_test();
        let mut proto_sources = vec![];
        let mut source_creators = vec![];
        for source_idx in 0..NUM_SOURCES {
            // Each source emits a distinct value, so chunks from different sources never
            // compare equal; otherwise `dedup` below could merge them and spuriously fail
            // the interleaving check.
            let value = source_idx + 1;
            let chunk = DataChunk::new(vec![I32Array::from_iter([value]).into_ref()], 1);
            let chunks = vec![Some(chunk); CHUNKS_PER_SOURCE];
            let fake_exchange_source = FakeExchangeSource::new(chunks);
            let fake_create_source = FakeCreateSource::new(fake_exchange_source);
            proto_sources.push(PbExchangeSource::default());
            source_creators.push(fake_create_source);
        }

        let executor = Box::new(GenericExchangeExecutor::<FakeCreateSource> {
            metrics: None,
            proto_sources,
            source_creators,
            sequential: false,
            context,
            schema: Schema {
                fields: vec![Field::unnamed(DataType::Int32)],
            },
            task_id: TaskId::default(),
            identity: "GenericExchangeExecutor2".to_owned(),
        });

        let mut stream = executor.execute();
        let mut chunks: Vec<DataChunk> = vec![];
        let mut total_chunks = 0;
        while let Some(chunk) = stream.next().await {
            chunks.push(chunk.unwrap());
            total_chunks += 1;
            // A window as large as one source's output must still contain chunks from
            // more than one source if the sources are really polled concurrently.
            if chunks.len() == CHUNKS_PER_SOURCE {
                chunks.dedup();
                assert_ne!(chunks.len(), 1);
                chunks.clear();
            }
        }
        assert_eq!(total_chunks, NUM_SOURCES as usize * CHUNKS_PER_SOURCE);
    }
}