risingwave_stream/executor/exchange/
permit.rs1use std::sync::Arc;
18
19use risingwave_common::config::StreamingConfig;
20use risingwave_pb::task_service::permits;
21use tokio::sync::{AcquireError, Semaphore, SemaphorePermit, mpsc};
22
23use crate::executor::DispatcherMessageBatch as Message;
24
25pub struct MessageWithPermits {
31 pub message: Message,
32 pub permits: Option<permits::Value>,
33}
34
35pub fn channel(
37 initial_permits: usize,
38 batched_permits: usize,
39 concurrent_barriers: usize,
40) -> (Sender, Receiver) {
41 let (tx, rx) = mpsc::unbounded_channel();
43
44 let records = Semaphore::new(initial_permits);
45 let barriers = Semaphore::new(concurrent_barriers);
46 let permits = Arc::new(Permits { records, barriers });
47
48 let max_chunk_permits: usize = initial_permits - batched_permits;
49
50 (
51 Sender {
52 tx,
53 permits: permits.clone(),
54 max_chunk_permits,
55 },
56 Receiver { rx, permits },
57 )
58}
59
60pub fn channel_from_config(config: &StreamingConfig) -> (Sender, Receiver) {
61 channel(
62 config.developer.exchange_initial_permits,
63 config.developer.exchange_batched_permits,
64 config.developer.exchange_concurrent_barriers,
65 )
66}
67
/// Permit settings used by tests. They are very large, so tests rarely contend for permits.
///
/// NOTE(review): the values are presumably kept within `u32` because permit counts are
/// converted with `as _` into protobuf permit values elsewhere in this file — confirm.
pub mod for_test {
    /// Record permits available up front: half of `u32::MAX`.
    pub const INITIAL_PERMITS: usize = u32::MAX as usize / 2;
    /// Record permits reserved for batching; a chunk may use at most `INITIAL_PERMITS - 1`.
    pub const BATCHED_PERMITS: usize = 1;
    /// Barrier batches that may be in flight at once: half of `u32::MAX`.
    pub const CONCURRENT_BARRIERS: usize = u32::MAX as usize / 2;
}
74
75pub fn channel_for_test() -> (Sender, Receiver) {
76 use for_test::*;
77
78 channel(INITIAL_PERMITS, BATCHED_PERMITS, CONCURRENT_BARRIERS)
79}
80
81pub struct Permits {
85 records: Semaphore,
87 barriers: Semaphore,
89}
90
91impl Permits {
92 pub fn add_permits(&self, permits: permits::Value) {
94 match permits {
95 permits::Value::Record(p) => self.records.add_permits(p as usize),
96 permits::Value::Barrier(p) => self.barriers.add_permits(p as usize),
97 }
98 }
99
100 async fn acquire_permits(&self, permits: &permits::Value) -> Result<(), AcquireError> {
104 match permits {
105 permits::Value::Record(p) => self.records.acquire_many(*p as _),
106 permits::Value::Barrier(p) => self.barriers.acquire_many(*p as _),
107 }
108 .await
109 .map(SemaphorePermit::forget)
110 }
111
112 fn close(&self) {
114 self.records.close();
115 self.barriers.close();
116 }
117}
118
119pub struct Sender {
121 tx: mpsc::UnboundedSender<MessageWithPermits>,
122 permits: Arc<Permits>,
123
124 max_chunk_permits: usize,
128}
129
130impl Sender {
131 pub async fn send(&self, message: Message) -> Result<(), mpsc::error::SendError<Message>> {
135 let permits = match &message {
137 Message::Chunk(c) => {
138 let card = c.cardinality().clamp(1, self.max_chunk_permits);
139 if card == self.max_chunk_permits {
140 tracing::warn!(cardinality = c.cardinality(), "large chunk in exchange")
141 }
142 Some(permits::Value::Record(card as _))
143 }
144 Message::BarrierBatch(_) => Some(permits::Value::Barrier(1)),
145 Message::Watermark(_) => None,
146 };
147
148 if let Some(permits) = &permits {
149 if self.permits.acquire_permits(permits).await.is_err() {
150 return Err(mpsc::error::SendError(message));
151 }
152 }
153
154 self.tx
155 .send(MessageWithPermits { message, permits })
156 .map_err(|e| mpsc::error::SendError(e.0.message))
157 }
158}
159
160pub struct Receiver {
162 rx: mpsc::UnboundedReceiver<MessageWithPermits>,
163 permits: Arc<Permits>,
164}
165
166impl Receiver {
167 pub async fn recv(&mut self) -> Option<Message> {
172 let MessageWithPermits { message, permits } = self.recv_raw().await?;
173
174 if let Some(permits) = permits {
175 self.permits.add_permits(permits);
176 }
177
178 Some(message)
179 }
180
181 pub fn try_recv(&mut self) -> Result<Message, mpsc::error::TryRecvError> {
186 let MessageWithPermits { message, permits } = self.rx.try_recv()?;
187
188 if let Some(permits) = permits {
189 self.permits.add_permits(permits);
190 }
191
192 Ok(message)
193 }
194
195 pub async fn recv_raw(&mut self) -> Option<MessageWithPermits> {
201 self.rx.recv().await
202 }
203
204 pub fn permits(&self) -> Arc<Permits> {
206 self.permits.clone()
207 }
208}
209
210impl Drop for Receiver {
211 fn drop(&mut self) {
212 self.permits.close();
215 }
216}
217
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::pin::pin;

    use futures::FutureExt;

    use super::*;
    use crate::executor::DispatcherBarrier as Barrier;

    #[test]
    fn test_channel_close() {
        // No record permits and exactly one barrier permit.
        let (tx, mut rx) = channel(0, 0, 1);

        let send_barrier = || {
            tx.send(Message::BarrierBatch(vec![
                Barrier::with_prev_epoch_for_test(514, 114),
            ]))
        };

        // The first barrier takes the only barrier permit and completes at once.
        assert_matches!(send_barrier().now_or_never(), Some(Ok(_)));
        // Receiving it returns the permit to the pool...
        assert_matches!(
            rx.recv().now_or_never(),
            Some(Some(Message::BarrierBatch(_)))
        );
        // ...so a second barrier can be sent immediately.
        assert_matches!(send_barrier().now_or_never(), Some(Ok(_)));

        // A third barrier must wait: the second one still holds the permit.
        let mut pending = pin!(send_barrier());
        assert_matches!((&mut pending).now_or_never(), None);

        // Dropping the receiver closes the semaphores, so the waiting send fails.
        drop(rx);
        assert_matches!(pending.now_or_never(), Some(Err(_)));
    }
}