risingwave_stream/executor/exchange/
permit.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Channel implementation for permit-based back-pressure.
16
17use std::sync::Arc;
18
19use risingwave_pb::task_service::permits;
20use tokio::sync::{AcquireError, Semaphore, SemaphorePermit, mpsc};
21
22use crate::executor::DispatcherMessageBatch as Message;
23
24/// Message with its required permits.
25///
26/// We store the `permits` in the struct instead of implying it from the `message` so that the
27/// permit number is totally determined by the sender and the downstream only needs to give the
28/// `permits` back verbatim, in case the version of the upstream and the downstream are different.
29pub struct MessageWithPermits {
30    pub message: Message,
31    pub permits: Option<permits::Value>,
32}
33
34/// Create a channel for the exchange service.
35pub fn channel(
36    initial_permits: usize,
37    batched_permits: usize,
38    concurrent_barriers: usize,
39) -> (Sender, Receiver) {
40    // Use an unbounded channel since we manage the permits manually.
41    let (tx, rx) = mpsc::unbounded_channel();
42
43    let records = Semaphore::new(initial_permits);
44    let barriers = Semaphore::new(concurrent_barriers);
45    let permits = Arc::new(Permits { records, barriers });
46
47    let max_chunk_permits: usize = initial_permits - batched_permits;
48
49    (
50        Sender {
51            tx,
52            permits: permits.clone(),
53            max_chunk_permits,
54        },
55        Receiver { rx, permits },
56    )
57}
58
/// Channel parameters used by tests.
pub mod for_test {
    /// Initial record permits for test channels: half of `u32::MAX`.
    pub const INITIAL_PERMITS: usize = (u32::MAX >> 1) as usize;
    /// Batched permits for test channels.
    pub const BATCHED_PERMITS: usize = 1;
    /// Barrier permits for test channels: half of `u32::MAX`.
    pub const CONCURRENT_BARRIERS: usize = (u32::MAX >> 1) as usize;
}
65
66pub fn channel_for_test() -> (Sender, Receiver) {
67    use for_test::*;
68
69    channel(INITIAL_PERMITS, BATCHED_PERMITS, CONCURRENT_BARRIERS)
70}
71
72/// Semaphore-based permits to control the back-pressure.
73///
74/// The number of messages in the exchange channel is limited by these semaphores.
75pub struct Permits {
76    /// The permits for records in chunks.
77    records: Semaphore,
78    /// The permits for barriers.
79    barriers: Semaphore,
80}
81
82impl Permits {
83    /// Add permits back to the semaphores.
84    pub fn add_permits(&self, permits: permits::Value) {
85        match permits {
86            permits::Value::Record(p) => self.records.add_permits(p as usize),
87            permits::Value::Barrier(p) => self.barriers.add_permits(p as usize),
88        }
89    }
90
91    /// Acquire permits from the semaphores.
92    ///
93    /// This function is cancellation-safe except for the fairness of waking.
94    async fn acquire_permits(&self, permits: &permits::Value) -> Result<(), AcquireError> {
95        match permits {
96            permits::Value::Record(p) => self.records.acquire_many(*p as _),
97            permits::Value::Barrier(p) => self.barriers.acquire_many(*p as _),
98        }
99        .await
100        .map(SemaphorePermit::forget)
101    }
102
103    /// Close the semaphores so that all pending `acquire` will fail immediately.
104    fn close(&self) {
105        self.records.close();
106        self.barriers.close();
107    }
108}
109
110/// The sender of the exchange service with permit-based back-pressure.
111pub struct Sender {
112    tx: mpsc::UnboundedSender<MessageWithPermits>,
113    permits: Arc<Permits>,
114
115    /// The maximum permits required by a chunk. If there're too many rows in a chunk, we only
116    /// acquire these permits. `BATCHED_PERMITS` is subtracted to avoid deadlock with
117    /// batching.
118    max_chunk_permits: usize,
119}
120
121impl Sender {
122    /// Send a message, waiting until there are enough permits.
123    ///
124    /// Returns error if the receive half of the channel is closed, including the message passed.
125    pub async fn send(&self, message: Message) -> Result<(), mpsc::error::SendError<Message>> {
126        // The semaphores should never be closed.
127        let permits = match &message {
128            Message::Chunk(c) => {
129                let card = c.cardinality().clamp(1, self.max_chunk_permits);
130                if card == self.max_chunk_permits {
131                    tracing::warn!(cardinality = c.cardinality(), "large chunk in exchange")
132                }
133                Some(permits::Value::Record(card as _))
134            }
135            Message::BarrierBatch(_) => Some(permits::Value::Barrier(1)),
136            Message::Watermark(_) => None,
137        };
138
139        if let Some(permits) = &permits {
140            if self.permits.acquire_permits(permits).await.is_err() {
141                return Err(mpsc::error::SendError(message));
142            }
143        }
144
145        self.tx
146            .send(MessageWithPermits { message, permits })
147            .map_err(|e| mpsc::error::SendError(e.0.message))
148    }
149}
150
151/// The receiver of the exchange service with permit-based back-pressure.
152pub struct Receiver {
153    rx: mpsc::UnboundedReceiver<MessageWithPermits>,
154    permits: Arc<Permits>,
155}
156
157impl Receiver {
158    /// Receive the next message for this receiver, with the permits of this message added back.
159    /// Used for local exchange.
160    ///
161    /// Returns `None` if the channel has been closed.
162    pub async fn recv(&mut self) -> Option<Message> {
163        let MessageWithPermits { message, permits } = self.recv_raw().await?;
164
165        if let Some(permits) = permits {
166            self.permits.add_permits(permits);
167        }
168
169        Some(message)
170    }
171
172    /// Try to receive the next message for this receiver, with the permits of this message added
173    /// back.
174    ///
175    /// Returns error if the channel is currently empty.
176    pub fn try_recv(&mut self) -> Result<Message, mpsc::error::TryRecvError> {
177        let MessageWithPermits { message, permits } = self.rx.try_recv()?;
178
179        if let Some(permits) = permits {
180            self.permits.add_permits(permits);
181        }
182
183        Ok(message)
184    }
185
186    /// Receive the next message and its permits for this receiver, **without** adding the permits
187    /// back. Used for remote exchange where the permits should be manually added according to the
188    /// downstream actor.
189    ///
190    /// Returns `None` if the channel has been closed.
191    pub async fn recv_raw(&mut self) -> Option<MessageWithPermits> {
192        self.rx.recv().await
193    }
194
195    /// Get a reference to the inner [`Permits`] to manually add permits.
196    pub fn permits(&self) -> Arc<Permits> {
197        self.permits.clone()
198    }
199}
200
201impl Drop for Receiver {
202    fn drop(&mut self) {
203        // Close the `permits` semaphores so that all pending `acquire` on the sender side will fail
204        // immediately.
205        self.permits.close();
206    }
207}
208
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::pin::pin;

    use futures::FutureExt;

    use super::*;
    use crate::executor::DispatcherBarrier as Barrier;

    /// A send blocked on barrier permits must fail once the receiver is dropped.
    #[test]
    fn test_channel_close() {
        // No record permits, exactly one barrier permit.
        let (tx, mut rx) = channel(0, 0, 1);

        let send = || {
            let barrier = Barrier::with_prev_epoch_for_test(514, 114);
            tx.send(Message::BarrierBatch(vec![barrier]))
        };

        // The single barrier permit is available, so the first send completes right away.
        let first_send = send().now_or_never();
        assert_matches!(first_send, Some(Ok(_)));

        // Receiving the barrier gives the permit back.
        let received = rx.recv().now_or_never();
        assert_matches!(received, Some(Some(Message::BarrierBatch(_))));

        // Send again but do not receive, so the only barrier permit stays taken.
        let second_send = send().now_or_never();
        assert_matches!(second_send, Some(Ok(_)));

        // With no permit left, the next send must stay pending.
        let mut blocked_send = pin!(send());
        assert_matches!((&mut blocked_send).now_or_never(), None);

        // Dropping the receiver closes the semaphores, which fails the pending send.
        drop(rx);
        assert_matches!(blocked_send.now_or_never(), Some(Err(_)));
    }
}
243}