risingwave_dml/
txn_channel.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use futures::FutureExt;
18use risingwave_common::transaction::transaction_message::TxnMsg;
19use tokio::sync::{Semaphore, mpsc, oneshot};
20
/// The number of record permits held by a single message in the transaction channel.
///
/// Intentionally neither `Clone` nor `Copy`: [`Permits::add_permits`] consumes it by value,
/// so each acquired batch of permits can only be handed back once.
#[derive(Debug)]
pub struct PermitValue(u32);
22
23pub struct TxnMsgWithPermits {
24    pub txn_msg: TxnMsg,
25    pub notificator: oneshot::Sender<usize>,
26    pub permit_value: Option<PermitValue>,
27}
28
29/// Create a channel for transaction messages.
30pub fn txn_channel(max_chunk_permits: usize) -> (Sender, Receiver) {
31    // Use an unbounded channel since we manage the permits manually.
32    let (tx, rx) = mpsc::unbounded_channel();
33
34    let records = Semaphore::new(max_chunk_permits);
35    let permits = Arc::new(Permits { records });
36
37    (
38        Sender {
39            tx,
40            permits: permits.clone(),
41            max_chunk_permits,
42        },
43        Receiver { rx, permits },
44    )
45}
46
47/// Semaphore-based permits to control the back-pressure.
48///
49/// The number of messages in the transaction channel is limited by these semaphores.
50#[derive(Debug)]
51pub struct Permits {
52    /// The permits for records in chunks.
53    records: Semaphore,
54}
55
56impl Permits {
57    /// Add permits back to the semaphores.
58    pub fn add_permits(&self, permit_value: PermitValue) {
59        self.records.add_permits(permit_value.0 as usize)
60    }
61}
62
63/// The sender of the transaction channel with permit-based back-pressure.
64#[derive(Debug, Clone)]
65pub struct Sender {
66    pub tx: mpsc::UnboundedSender<TxnMsgWithPermits>,
67    permits: Arc<Permits>,
68
69    /// The maximum permits required by a chunk. If there're too many rows in a chunk, we only
70    /// acquire these permits.
71    max_chunk_permits: usize,
72}
73
74impl Sender {
75    /// Send a message, waiting until there are enough permits.
76    /// Used to send transaction data messages.
77    ///
78    /// Returns error if the receive half of the channel is closed, including the message passed.
79    pub async fn send(
80        &self,
81        txn_msg: TxnMsg,
82        notificator: oneshot::Sender<usize>,
83    ) -> Result<(), mpsc::error::SendError<TxnMsg>> {
84        // The semaphores should never be closed.
85        let permits = match &txn_msg {
86            TxnMsg::Data(_, c) => {
87                let card = c.cardinality().clamp(1, self.max_chunk_permits);
88                if card == self.max_chunk_permits {
89                    tracing::warn!(
90                        cardinality = c.cardinality(),
91                        "large chunk in transaction channel"
92                    )
93                }
94                self.permits
95                    .records
96                    .acquire_many(card as _)
97                    .await
98                    .unwrap()
99                    .forget();
100                Some(PermitValue(card as _))
101            }
102            TxnMsg::Begin(_) | TxnMsg::Rollback(_) | TxnMsg::End(..) => None,
103        };
104
105        self.tx
106            .send(TxnMsgWithPermits {
107                txn_msg,
108                notificator,
109                permit_value: permits,
110            })
111            .map_err(|e| mpsc::error::SendError(e.0.txn_msg))
112    }
113
114    /// Send a message without permit acquiring.
115    /// Used to send transaction control messages.
116    ///
117    /// Returns error if the receive half of the channel is closed, including the message passed.
118    pub fn send_immediate(
119        &self,
120        txn_msg: TxnMsg,
121        notificator: oneshot::Sender<usize>,
122    ) -> Result<(), mpsc::error::SendError<TxnMsg>> {
123        self.send(txn_msg, notificator)
124            .now_or_never()
125            .expect("cannot send immediately")
126    }
127
128    pub fn is_closed(&self) -> bool {
129        self.tx.is_closed()
130    }
131}
132
133/// The receiver of the txn channel with permit-based back-pressure.
134#[derive(Debug)]
135pub struct Receiver {
136    rx: mpsc::UnboundedReceiver<TxnMsgWithPermits>,
137    permits: Arc<Permits>,
138}
139
140impl Receiver {
141    /// Receive the next message for this receiver, with the permits of this message added back.
142    ///
143    /// Returns `None` if the channel has been closed.
144    pub async fn recv(&mut self) -> Option<(TxnMsg, oneshot::Sender<usize>)> {
145        let TxnMsgWithPermits {
146            txn_msg,
147            notificator,
148            permit_value: permits,
149        } = self.rx.recv().await?;
150
151        if let Some(permits) = permits {
152            self.permits.add_permits(permits);
153        }
154
155        Some((txn_msg, notificator))
156    }
157}