risingwave_stream/executor/exchange/
permit.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Channel implementation for permit-based back-pressure.
16
17use std::sync::Arc;
18
19use risingwave_common::config::StreamingConfig;
20use risingwave_pb::task_service::permits;
21use tokio::sync::{AcquireError, Semaphore, SemaphorePermit, mpsc};
22
23use crate::executor::DispatcherMessageBatch as Message;
24
25/// Message with its required permits.
26///
27/// We store the `permits` in the struct instead of implying it from the `message` so that the
28/// permit number is totally determined by the sender and the downstream only needs to give the
29/// `permits` back verbatim, in case the version of the upstream and the downstream are different.
30pub struct MessageWithPermits {
31    pub message: Message,
32    pub permits: Option<permits::Value>,
33}
34
35/// Create a channel for the exchange service.
36pub fn channel(
37    initial_permits: usize,
38    batched_permits: usize,
39    concurrent_barriers: usize,
40) -> (Sender, Receiver) {
41    // Use an unbounded channel since we manage the permits manually.
42    let (tx, rx) = mpsc::unbounded_channel();
43
44    let records = Semaphore::new(initial_permits);
45    let barriers = Semaphore::new(concurrent_barriers);
46    let permits = Arc::new(Permits { records, barriers });
47
48    let max_chunk_permits: usize = initial_permits - batched_permits;
49
50    (
51        Sender {
52            tx,
53            permits: permits.clone(),
54            max_chunk_permits,
55        },
56        Receiver { rx, permits },
57    )
58}
59
60pub fn channel_from_config(config: &StreamingConfig) -> (Sender, Receiver) {
61    channel(
62        config.developer.exchange_initial_permits,
63        config.developer.exchange_batched_permits,
64        config.developer.exchange_concurrent_barriers,
65    )
66}
67
/// Permit settings used when building channels in tests.
pub mod for_test {
    /// Number of record permits a test channel starts with: half of `u32::MAX`.
    pub const INITIAL_PERMITS: usize = u32::MAX as usize / 2;
    /// Number of batched permits used by a test channel.
    pub const BATCHED_PERMITS: usize = 1;
    /// Number of barrier permits a test channel starts with: half of `u32::MAX`.
    pub const CONCURRENT_BARRIERS: usize = u32::MAX as usize / 2;
}
74
75pub fn channel_for_test() -> (Sender, Receiver) {
76    use for_test::*;
77
78    channel(INITIAL_PERMITS, BATCHED_PERMITS, CONCURRENT_BARRIERS)
79}
80
81/// Semaphore-based permits to control the back-pressure.
82///
83/// The number of messages in the exchange channel is limited by these semaphores.
84pub struct Permits {
85    /// The permits for records in chunks.
86    records: Semaphore,
87    /// The permits for barriers.
88    barriers: Semaphore,
89}
90
91impl Permits {
92    /// Add permits back to the semaphores.
93    pub fn add_permits(&self, permits: permits::Value) {
94        match permits {
95            permits::Value::Record(p) => self.records.add_permits(p as usize),
96            permits::Value::Barrier(p) => self.barriers.add_permits(p as usize),
97        }
98    }
99
100    /// Acquire permits from the semaphores.
101    ///
102    /// This function is cancellation-safe except for the fairness of waking.
103    async fn acquire_permits(&self, permits: &permits::Value) -> Result<(), AcquireError> {
104        match permits {
105            permits::Value::Record(p) => self.records.acquire_many(*p as _),
106            permits::Value::Barrier(p) => self.barriers.acquire_many(*p as _),
107        }
108        .await
109        .map(SemaphorePermit::forget)
110    }
111
112    /// Close the semaphores so that all pending `acquire` will fail immediately.
113    fn close(&self) {
114        self.records.close();
115        self.barriers.close();
116    }
117}
118
119/// The sender of the exchange service with permit-based back-pressure.
120pub struct Sender {
121    tx: mpsc::UnboundedSender<MessageWithPermits>,
122    permits: Arc<Permits>,
123
124    /// The maximum permits required by a chunk. If there're too many rows in a chunk, we only
125    /// acquire these permits. `BATCHED_PERMITS` is subtracted to avoid deadlock with
126    /// batching.
127    max_chunk_permits: usize,
128}
129
130impl Sender {
131    /// Send a message, waiting until there are enough permits.
132    ///
133    /// Returns error if the receive half of the channel is closed, including the message passed.
134    pub async fn send(&self, message: Message) -> Result<(), mpsc::error::SendError<Message>> {
135        // The semaphores should never be closed.
136        let permits = match &message {
137            Message::Chunk(c) => {
138                let card = c.cardinality().clamp(1, self.max_chunk_permits);
139                if card == self.max_chunk_permits {
140                    tracing::warn!(cardinality = c.cardinality(), "large chunk in exchange")
141                }
142                Some(permits::Value::Record(card as _))
143            }
144            Message::BarrierBatch(_) => Some(permits::Value::Barrier(1)),
145            Message::Watermark(_) => None,
146        };
147
148        if let Some(permits) = &permits {
149            if self.permits.acquire_permits(permits).await.is_err() {
150                return Err(mpsc::error::SendError(message));
151            }
152        }
153
154        self.tx
155            .send(MessageWithPermits { message, permits })
156            .map_err(|e| mpsc::error::SendError(e.0.message))
157    }
158}
159
160/// The receiver of the exchange service with permit-based back-pressure.
161pub struct Receiver {
162    rx: mpsc::UnboundedReceiver<MessageWithPermits>,
163    permits: Arc<Permits>,
164}
165
166impl Receiver {
167    /// Receive the next message for this receiver, with the permits of this message added back.
168    /// Used for local exchange.
169    ///
170    /// Returns `None` if the channel has been closed.
171    pub async fn recv(&mut self) -> Option<Message> {
172        let MessageWithPermits { message, permits } = self.recv_raw().await?;
173
174        if let Some(permits) = permits {
175            self.permits.add_permits(permits);
176        }
177
178        Some(message)
179    }
180
181    /// Try to receive the next message for this receiver, with the permits of this message added
182    /// back.
183    ///
184    /// Returns error if the channel is currently empty.
185    pub fn try_recv(&mut self) -> Result<Message, mpsc::error::TryRecvError> {
186        let MessageWithPermits { message, permits } = self.rx.try_recv()?;
187
188        if let Some(permits) = permits {
189            self.permits.add_permits(permits);
190        }
191
192        Ok(message)
193    }
194
195    /// Receive the next message and its permits for this receiver, **without** adding the permits
196    /// back. Used for remote exchange where the permits should be manually added according to the
197    /// downstream actor.
198    ///
199    /// Returns `None` if the channel has been closed.
200    pub async fn recv_raw(&mut self) -> Option<MessageWithPermits> {
201        self.rx.recv().await
202    }
203
204    /// Get a reference to the inner [`Permits`] to manually add permits.
205    pub fn permits(&self) -> Arc<Permits> {
206        self.permits.clone()
207    }
208}
209
210impl Drop for Receiver {
211    fn drop(&mut self) {
212        // Close the `permits` semaphores so that all pending `acquire` on the sender side will fail
213        // immediately.
214        self.permits.close();
215    }
216}
217
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::pin::pin;

    use futures::FutureExt;

    use super::*;
    use crate::executor::DispatcherBarrier as Barrier;

    /// Dropping the receiver must fail a send that is blocked waiting for permits.
    #[test]
    fn test_channel_close() {
        // No record permits at all and exactly one barrier permit.
        let (sender, mut receiver) = channel(0, 0, 1);

        let send_barrier = || {
            let barrier = Barrier::with_prev_epoch_for_test(514, 114);
            sender.send(Message::BarrierBatch(vec![barrier]))
        };

        // The first barrier takes the only barrier permit and is sent successfully.
        assert_matches!(send_barrier().now_or_never(), Some(Ok(_)));
        // Receiving it succeeds and gives the permit back.
        assert_matches!(
            receiver.recv().now_or_never(),
            Some(Some(Message::BarrierBatch(_)))
        );

        // The second barrier is sent successfully as well, but it is deliberately not received,
        // so the channel stays full.
        assert_matches!(send_barrier().now_or_never(), Some(Ok(_)));

        // The third send would block since there are no permits left.
        let mut blocked_send = pin!(send_barrier());
        assert_matches!((&mut blocked_send).now_or_never(), None);

        // Dropping the receiver closes the channel, so the blocked send fails.
        drop(receiver);
        assert_matches!(blocked_send.now_or_never(), Some(Err(_)));
    }
}