risingwave_stream/executor/exchange/
permit.rs1use std::sync::Arc;
18
19use risingwave_pb::task_service::permits;
20use tokio::sync::{AcquireError, Semaphore, SemaphorePermit, mpsc};
21
22use crate::executor::DispatcherMessageBatch as Message;
23
24pub struct MessageWithPermits {
30 pub message: Message,
31 pub permits: Option<permits::Value>,
32}
33
34pub fn channel(
36 initial_permits: usize,
37 batched_permits: usize,
38 concurrent_barriers: usize,
39) -> (Sender, Receiver) {
40 let (tx, rx) = mpsc::unbounded_channel();
42
43 let records = Semaphore::new(initial_permits);
44 let barriers = Semaphore::new(concurrent_barriers);
45 let permits = Arc::new(Permits { records, barriers });
46
47 let max_chunk_permits: usize = initial_permits - batched_permits;
48
49 (
50 Sender {
51 tx,
52 permits: permits.clone(),
53 max_chunk_permits,
54 },
55 Receiver { rx, permits },
56 )
57}
58
/// Permit settings used by [`channel_for_test`]; the values are presumably chosen so large
/// that tests never block on permits in practice.
pub mod for_test {
    /// Initial record permits: half of `u32::MAX`.
    pub const INITIAL_PERMITS: usize = (u32::MAX >> 1) as usize;
    /// Permits subtracted from `INITIAL_PERMITS` for the per-chunk cap.
    pub const BATCHED_PERMITS: usize = 1;
    /// Concurrent barrier batches allowed in flight: half of `u32::MAX`.
    pub const CONCURRENT_BARRIERS: usize = (u32::MAX >> 1) as usize;
}
65
66pub fn channel_for_test() -> (Sender, Receiver) {
67 use for_test::*;
68
69 channel(INITIAL_PERMITS, BATCHED_PERMITS, CONCURRENT_BARRIERS)
70}
71
72pub struct Permits {
76 records: Semaphore,
78 barriers: Semaphore,
80}
81
82impl Permits {
83 pub fn add_permits(&self, permits: permits::Value) {
85 match permits {
86 permits::Value::Record(p) => self.records.add_permits(p as usize),
87 permits::Value::Barrier(p) => self.barriers.add_permits(p as usize),
88 }
89 }
90
91 async fn acquire_permits(&self, permits: &permits::Value) -> Result<(), AcquireError> {
95 match permits {
96 permits::Value::Record(p) => self.records.acquire_many(*p as _),
97 permits::Value::Barrier(p) => self.barriers.acquire_many(*p as _),
98 }
99 .await
100 .map(SemaphorePermit::forget)
101 }
102
103 fn close(&self) {
105 self.records.close();
106 self.barriers.close();
107 }
108}
109
110pub struct Sender {
112 tx: mpsc::UnboundedSender<MessageWithPermits>,
113 permits: Arc<Permits>,
114
115 max_chunk_permits: usize,
119}
120
121impl Sender {
122 pub async fn send(&self, message: Message) -> Result<(), mpsc::error::SendError<Message>> {
126 let permits = match &message {
128 Message::Chunk(c) => {
129 let card = c.cardinality().clamp(1, self.max_chunk_permits);
130 if card == self.max_chunk_permits {
131 tracing::warn!(cardinality = c.cardinality(), "large chunk in exchange")
132 }
133 Some(permits::Value::Record(card as _))
134 }
135 Message::BarrierBatch(_) => Some(permits::Value::Barrier(1)),
136 Message::Watermark(_) => None,
137 };
138
139 if let Some(permits) = &permits {
140 if self.permits.acquire_permits(permits).await.is_err() {
141 return Err(mpsc::error::SendError(message));
142 }
143 }
144
145 self.tx
146 .send(MessageWithPermits { message, permits })
147 .map_err(|e| mpsc::error::SendError(e.0.message))
148 }
149}
150
151pub struct Receiver {
153 rx: mpsc::UnboundedReceiver<MessageWithPermits>,
154 permits: Arc<Permits>,
155}
156
157impl Receiver {
158 pub async fn recv(&mut self) -> Option<Message> {
163 let MessageWithPermits { message, permits } = self.recv_raw().await?;
164
165 if let Some(permits) = permits {
166 self.permits.add_permits(permits);
167 }
168
169 Some(message)
170 }
171
172 pub fn try_recv(&mut self) -> Result<Message, mpsc::error::TryRecvError> {
177 let MessageWithPermits { message, permits } = self.rx.try_recv()?;
178
179 if let Some(permits) = permits {
180 self.permits.add_permits(permits);
181 }
182
183 Ok(message)
184 }
185
186 pub async fn recv_raw(&mut self) -> Option<MessageWithPermits> {
192 self.rx.recv().await
193 }
194
195 pub fn permits(&self) -> Arc<Permits> {
197 self.permits.clone()
198 }
199}
200
201impl Drop for Receiver {
202 fn drop(&mut self) {
203 self.permits.close();
206 }
207}
208
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::pin::pin;

    use futures::FutureExt;

    use super::*;
    use crate::executor::DispatcherBarrier as Barrier;

    /// Dropping the receiver must close the permit pool so that a sender blocked on permits
    /// observes an error rather than hanging.
    #[test]
    fn test_channel_close() {
        // No record permits, and room for exactly one in-flight barrier batch.
        let (tx, mut rx) = channel(0, 0, 1);

        let send_barrier = || {
            let barrier = Barrier::with_prev_epoch_for_test(514, 114);
            tx.send(Message::BarrierBatch(vec![barrier]))
        };

        // The single barrier permit is free, so the first send completes immediately.
        assert_matches!(send_barrier().now_or_never(), Some(Ok(_)));
        // Receiving the batch hands its permit back to the pool...
        assert_matches!(
            rx.recv().now_or_never(),
            Some(Some(Message::BarrierBatch(_)))
        );
        // ...so the next send also completes immediately.
        assert_matches!(send_barrier().now_or_never(), Some(Ok(_)));

        // The pool is now exhausted: a further send stays pending.
        let mut blocked = pin!(send_barrier());
        assert_matches!((&mut blocked).now_or_never(), None);

        // Dropping the receiver closes the pool and fails the pending send.
        drop(rx);
        assert_matches!(blocked.now_or_never(), Some(Err(_)));
    }
}