risingwave_stream/executor/exchange/
input.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use either::Either;
19use local_input::LocalInputStreamInner;
20use pin_project::pin_project;
21use risingwave_common::config::StreamingConfig;
22use risingwave_common::util::addr::{HostAddr, is_local_address};
23
24use super::permit::Receiver;
25use crate::executor::prelude::*;
26use crate::executor::{
27    BarrierInner, DispatcherMessage, DispatcherMessageBatch, DispatcherMessageStreamItem,
28};
29use crate::task::{FragmentId, LocalBarrierManager, UpDownActorIds, UpDownFragmentIds};
30
31/// `Input` is a more abstract upstream input type, used for `DynamicReceivers` type
32/// handling of multiple upstream inputs
33pub trait Input: Stream + Send {
34    type InputId;
35    /// The upstream input id.
36    fn id(&self) -> Self::InputId;
37
38    fn boxed_input(self) -> BoxedInput<Self::InputId, Self::Item>
39    where
40        Self: Sized + 'static,
41    {
42        Box::pin(self)
43    }
44}
45
46pub type BoxedInput<InputId, Item> = Pin<Box<dyn Input<InputId = InputId, Item = Item>>>;
47
48/// `ActorInput` provides an interface for [`MergeExecutor`](crate::executor::MergeExecutor) and
49/// [`ReceiverExecutor`](crate::executor::ReceiverExecutor) to receive data from upstream actors.
50/// Only used for actor inputs.
51pub trait ActorInput = Input<Item = DispatcherMessageStreamItem, InputId = ActorId>;
52
53pub type BoxedActorInput = Pin<Box<dyn ActorInput>>;
54
55impl std::fmt::Debug for dyn ActorInput {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("Input")
58            .field("actor_id", &self.id())
59            .finish_non_exhaustive()
60    }
61}
62
63/// `LocalInput` receives data from a local channel.
64#[pin_project]
65pub struct LocalInput {
66    #[pin]
67    inner: LocalInputStreamInner,
68
69    actor_id: ActorId,
70}
71
72pub(crate) fn assert_equal_dispatcher_barrier<M1, M2>(
73    first: &BarrierInner<M1>,
74    second: &BarrierInner<M2>,
75) {
76    assert_eq!(first.epoch, second.epoch);
77    assert_eq!(first.kind, second.kind);
78}
79
80impl LocalInput {
81    pub fn new(channel: Receiver, upstream_actor_id: ActorId) -> Self {
82        Self {
83            inner: local_input::run(channel, upstream_actor_id),
84            actor_id: upstream_actor_id,
85        }
86    }
87}
88
89mod local_input {
90    use await_tree::InstrumentAwait;
91    use either::Either;
92
93    use crate::executor::exchange::error::ExchangeChannelClosed;
94    use crate::executor::exchange::permit::Receiver;
95    use crate::executor::prelude::try_stream;
96    use crate::executor::{DispatcherMessage, StreamExecutorError};
97    use crate::task::ActorId;
98
99    pub(super) type LocalInputStreamInner = impl crate::executor::DispatcherMessageStream;
100
101    #[define_opaque(LocalInputStreamInner)]
102    pub(super) fn run(channel: Receiver, upstream_actor_id: ActorId) -> LocalInputStreamInner {
103        run_inner(channel, upstream_actor_id)
104    }
105
106    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
107    async fn run_inner(mut channel: Receiver, upstream_actor_id: ActorId) {
108        let span = await_tree::span!("LocalInput (actor {upstream_actor_id})").verbose();
109        while let Some(msg) = channel.recv().instrument_await(span.clone()).await {
110            match msg.into_messages() {
111                Either::Left(barriers) => {
112                    for b in barriers {
113                        yield b;
114                    }
115                }
116                Either::Right(m) => {
117                    yield m;
118                }
119            }
120        }
121        // Always emit an error outside the loop. This is because we use barrier as the control
122        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
123        Err(ExchangeChannelClosed::local_input(upstream_actor_id))?
124    }
125}
126
127impl Stream for LocalInput {
128    type Item = DispatcherMessageStreamItem;
129
130    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
131        // TODO: shall we pass the error with local exchange?
132        self.project().inner.poll_next(cx)
133    }
134}
135
136impl Input for LocalInput {
137    type InputId = ActorId;
138
139    fn id(&self) -> Self::InputId {
140        self.actor_id
141    }
142}
143
144/// `RemoteInput` connects to the upstream exchange server and receives data with `gRPC`.
145#[pin_project]
146pub struct RemoteInput {
147    #[pin]
148    inner: RemoteInputStreamInner,
149
150    actor_id: ActorId,
151}
152
153use remote_input::RemoteInputStreamInner;
154use risingwave_pb::common::ActorInfo;
155use risingwave_pb::id::PartialGraphId;
156
157impl RemoteInput {
158    /// Create a remote input from compute client and related info. Should provide the corresponding
159    /// compute client of where the actor is placed.
160    pub async fn new(
161        local_barrier_manager: &LocalBarrierManager,
162        upstream_addr: HostAddr,
163        upstream_partial_graph_id: PartialGraphId,
164        up_down_ids: UpDownActorIds,
165        up_down_frag: UpDownFragmentIds,
166        metrics: Arc<StreamingMetrics>,
167        actor_config: Arc<StreamingConfig>,
168    ) -> StreamExecutorResult<Self> {
169        let actor_id = up_down_ids.0;
170
171        let client = local_barrier_manager
172            .env
173            .client_pool()
174            .get_by_addr(upstream_addr)
175            .await?;
176        let (stream, permits_tx) = client
177            .get_stream(
178                up_down_ids.0,
179                up_down_ids.1,
180                up_down_frag.0,
181                up_down_frag.1,
182                upstream_partial_graph_id,
183                local_barrier_manager.term_id.clone(),
184            )
185            .await?;
186
187        Ok(Self {
188            actor_id,
189            inner: remote_input::run(
190                stream,
191                permits_tx,
192                up_down_ids,
193                up_down_frag,
194                metrics,
195                actor_config.developer.exchange_batched_permits,
196            ),
197        })
198    }
199}
200
201mod remote_input {
202    use std::sync::Arc;
203
204    use anyhow::Context;
205    use await_tree::InstrumentAwait;
206    use either::Either;
207    use risingwave_pb::task_service::{GetStreamResponse, permits};
208    use tokio::sync::mpsc;
209    use tonic::Streaming;
210
211    use crate::executor::exchange::error::ExchangeChannelClosed;
212    use crate::executor::monitor::StreamingMetrics;
213    use crate::executor::prelude::{StreamExt, pin_mut, try_stream};
214    use crate::executor::{DispatcherMessage, StreamExecutorError};
215    use crate::task::{UpDownActorIds, UpDownFragmentIds};
216
217    pub(super) type RemoteInputStreamInner = impl crate::executor::DispatcherMessageStream;
218
219    #[define_opaque(RemoteInputStreamInner)]
220    pub(super) fn run(
221        stream: Streaming<GetStreamResponse>,
222        permits_tx: mpsc::UnboundedSender<permits::Value>,
223        up_down_ids: UpDownActorIds,
224        up_down_frag: UpDownFragmentIds,
225        metrics: Arc<StreamingMetrics>,
226        batched_permits_limit: usize,
227    ) -> RemoteInputStreamInner {
228        run_inner(
229            stream,
230            permits_tx,
231            up_down_ids,
232            up_down_frag,
233            metrics,
234            batched_permits_limit,
235        )
236    }
237
238    #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)]
239    async fn run_inner(
240        stream: Streaming<GetStreamResponse>,
241        permits_tx: mpsc::UnboundedSender<permits::Value>,
242        up_down_ids: UpDownActorIds,
243        up_down_frag: UpDownFragmentIds,
244        metrics: Arc<StreamingMetrics>,
245        batched_permits_limit: usize,
246    ) {
247        let up_actor_id = up_down_ids.0.to_string();
248        let up_fragment_id = up_down_frag.0.to_string();
249        let down_fragment_id = up_down_frag.1.to_string();
250        let exchange_frag_recv_size_metrics = metrics
251            .exchange_frag_recv_size
252            .with_guarded_label_values(&[&up_fragment_id, &down_fragment_id]);
253
254        let span = await_tree::span!("RemoteInput (actor {up_actor_id})").verbose();
255
256        let mut batched_permits_accumulated = 0;
257
258        pin_mut!(stream);
259        while let Some(data_res) = stream.next().instrument_await(span.clone()).await {
260            match data_res {
261                Ok(GetStreamResponse { message, permits }) => {
262                    use crate::executor::DispatcherMessageBatch;
263                    let msg = message.unwrap();
264                    let bytes = DispatcherMessageBatch::get_encoded_len(&msg);
265
266                    exchange_frag_recv_size_metrics.inc_by(bytes as u64);
267
268                    let msg_res = DispatcherMessageBatch::from_protobuf(&msg);
269                    if let Some(add_back_permits) = match permits.unwrap().value {
270                        // For records, batch the permits we received to reduce the backward
271                        // `AddPermits` messages.
272                        Some(permits::Value::Record(p)) => {
273                            batched_permits_accumulated += p;
274                            if batched_permits_accumulated >= batched_permits_limit as u32 {
275                                let permits = std::mem::take(&mut batched_permits_accumulated);
276                                Some(permits::Value::Record(permits))
277                            } else {
278                                None
279                            }
280                        }
281                        // For barriers, always send it back immediately.
282                        Some(permits::Value::Barrier(p)) => Some(permits::Value::Barrier(p)),
283                        None => None,
284                    } {
285                        permits_tx
286                            .send(add_back_permits)
287                            .context("RemoteInput backward permits channel closed.")?;
288                    }
289
290                    let msg = msg_res.context("RemoteInput decode message error")?;
291                    match msg.into_messages() {
292                        Either::Left(barriers) => {
293                            for b in barriers {
294                                yield b;
295                            }
296                        }
297                        Either::Right(m) => {
298                            yield m;
299                        }
300                    }
301                }
302
303                Err(e) => Err(ExchangeChannelClosed::remote_input(up_down_ids.0, Some(e)))?,
304            }
305        }
306
307        // Always emit an error outside the loop. This is because we use barrier as the control
308        // message to stop the stream. Reaching here means the channel is closed unexpectedly.
309        Err(ExchangeChannelClosed::remote_input(up_down_ids.0, None))?
310    }
311}
312
313impl Stream for RemoteInput {
314    type Item = DispatcherMessageStreamItem;
315
316    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
317        self.project().inner.poll_next(cx)
318    }
319}
320
321impl Input for RemoteInput {
322    type InputId = ActorId;
323
324    fn id(&self) -> Self::InputId {
325        self.actor_id
326    }
327}
328
329/// Create a [`LocalInput`] or [`RemoteInput`] instance with given info. Used by merge executors and
330/// receiver executors.
331pub(crate) async fn new_input(
332    local_barrier_manager: &LocalBarrierManager,
333    metrics: Arc<StreamingMetrics>,
334    actor_id: ActorId,
335    fragment_id: FragmentId,
336    upstream_actor_info: &ActorInfo,
337    upstream_fragment_id: FragmentId,
338    actor_config: Arc<StreamingConfig>,
339) -> StreamExecutorResult<BoxedActorInput> {
340    let upstream_actor_id = upstream_actor_info.actor_id;
341    let upstream_addr = upstream_actor_info.get_host()?.into();
342
343    let input = if is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) {
344        LocalInput::new(
345            local_barrier_manager.register_local_upstream_output(
346                actor_id,
347                upstream_actor_id,
348                upstream_actor_info.partial_graph_id,
349            ),
350            upstream_actor_id,
351        )
352        .boxed_input()
353    } else {
354        RemoteInput::new(
355            local_barrier_manager,
356            upstream_addr,
357            upstream_actor_info.partial_graph_id,
358            (upstream_actor_id, actor_id),
359            (upstream_fragment_id, fragment_id),
360            metrics,
361            actor_config,
362        )
363        .await?
364        .boxed_input()
365    };
366
367    Ok(input)
368}
369
370impl DispatcherMessageBatch {
371    fn into_messages(self) -> Either<impl Iterator<Item = DispatcherMessage>, DispatcherMessage> {
372        match self {
373            DispatcherMessageBatch::BarrierBatch(barriers) => {
374                Either::Left(barriers.into_iter().map(DispatcherMessage::Barrier))
375            }
376            DispatcherMessageBatch::Chunk(c) => Either::Right(DispatcherMessage::Chunk(c)),
377            DispatcherMessageBatch::Watermark(w) => Either::Right(DispatcherMessage::Watermark(w)),
378        }
379    }
380}