risingwave_frontend/binder/
bind_param.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use bytes::Bytes;
16use pgwire::types::{Format, FormatIterator};
17use risingwave_common::bail;
18use risingwave_common::error::BoxedError;
19use risingwave_common::types::{Datum, ScalarImpl};
20
21use super::BoundStatement;
22use super::statement::RewriteExprsRecursive;
23use crate::error::{ErrorCode, Result};
24use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal, default_rewrite_expr};
25
26/// Rewrites parameter expressions to literals.
27pub(crate) struct ParamRewriter {
28    pub(crate) params: Vec<Option<Bytes>>,
29    pub(crate) parsed_params: Vec<Datum>,
30    pub(crate) param_formats: Vec<Format>,
31    pub(crate) error: Option<BoxedError>,
32}
33
34impl ParamRewriter {
35    pub(crate) fn new(param_formats: Vec<Format>, params: Vec<Option<Bytes>>) -> Self {
36        Self {
37            parsed_params: vec![None; params.len()],
38            params,
39            param_formats,
40            error: None,
41        }
42    }
43}
44
45impl ExprRewriter for ParamRewriter {
46    fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
47        if self.error.is_some() {
48            return expr;
49        }
50        default_rewrite_expr(self, expr)
51    }
52
53    fn rewrite_subquery(&mut self, mut subquery: crate::expr::Subquery) -> ExprImpl {
54        subquery.query.rewrite_exprs_recursive(self);
55        subquery.into()
56    }
57
58    fn rewrite_parameter(&mut self, parameter: crate::expr::Parameter) -> ExprImpl {
59        let data_type = parameter.return_type();
60
61        // Postgresql parameter index is 1-based. e.g. $1,$2,$3
62        // But we store it in 0-based vector. So we need to minus 1.
63        let parameter_index = (parameter.index - 1) as usize;
64
65        fn cstr_to_str(b: &[u8]) -> std::result::Result<&str, BoxedError> {
66            let without_null = if b.last() == Some(&0) {
67                &b[..b.len() - 1]
68            } else {
69                b
70            };
71            Ok(std::str::from_utf8(without_null)?)
72        }
73
74        let datum: Datum = if let Some(val_bytes) = self.params[parameter_index].clone() {
75            let res = match self.param_formats[parameter_index] {
76                Format::Text => {
77                    cstr_to_str(&val_bytes).and_then(|str| ScalarImpl::from_text(str, &data_type))
78                }
79                Format::Binary => ScalarImpl::from_binary(&val_bytes, &data_type),
80            };
81            match res {
82                Ok(datum) => Some(datum),
83                Err(e) => {
84                    self.error = Some(e);
85                    return parameter.into();
86                }
87            }
88        } else {
89            None
90        };
91
92        self.parsed_params[parameter_index].clone_from(&datum);
93        Literal::new(datum, data_type).into()
94    }
95}
96
97impl BoundStatement {
98    pub fn bind_parameter(
99        mut self,
100        params: Vec<Option<Bytes>>,
101        param_formats: Vec<Format>,
102    ) -> Result<(BoundStatement, Vec<Datum>)> {
103        let mut rewriter = ParamRewriter::new(
104            FormatIterator::new(&param_formats, params.len())
105                .map_err(ErrorCode::BindError)?
106                .collect(),
107            params,
108        );
109
110        self.rewrite_exprs_recursive(&mut rewriter);
111
112        if let Some(err) = rewriter.error {
113            bail!(err);
114        }
115
116        Ok((self, rewriter.parsed_params))
117    }
118}
119
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use pgwire::types::Format;
    use risingwave_common::types::DataType;
    use risingwave_sqlparser::test_utils::parse_sql_statements;

    use crate::binder::BoundStatement;
    use crate::binder::test_utils::{mock_binder, mock_binder_with_param_types};

    /// Binds the first statement of `sql` using a mock binder without declared
    /// parameter types; used to build the expected result.
    fn create_expect_bound(sql: &str) -> BoundStatement {
        let mut binder = mock_binder();
        let statement = parse_sql_statements(sql).unwrap().remove(0);
        binder.bind(statement).unwrap()
    }

    /// Binds the first statement of `sql` with `param_types` declared. It then
    /// substitutes `params` (encoded as `param_formats`) for its parameters.
    fn create_actual_bound(
        sql: &str,
        param_types: Vec<Option<DataType>>,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
    ) -> BoundStatement {
        let mut binder = mock_binder_with_param_types(param_types);
        let statement = parse_sql_statements(sql).unwrap().remove(0);
        let (bound, _parsed_params) = binder
            .bind(statement)
            .unwrap()
            .bind_parameter(params, param_formats)
            .unwrap();
        bound
    }

    /// Asserts that two bound statements match by comparing their `Debug`
    /// renderings. This may switch to a structural comparison in future.
    fn expect_actual_eq(expect: BoundStatement, actual: BoundStatement) {
        let expected = format!("{expect:?}");
        let got = format!("{actual:?}");
        assert_eq!(expected, got);
    }

    /// A single text-format parameter carrying the value `1`.
    fn single_text_one() -> (Vec<Option<Bytes>>, Vec<Format>) {
        (vec![Some(Bytes::from("1"))], vec![Format::Text])
    }

    #[tokio::test]
    async fn basic_select() {
        let (params, formats) = single_text_one();
        expect_actual_eq(
            create_expect_bound("select 1::int4"),
            create_actual_bound("select $1::int4", vec![], params, formats),
        );
    }

    #[tokio::test]
    async fn basic_value() {
        let (params, formats) = single_text_one();
        expect_actual_eq(
            create_expect_bound("values(1::int4)"),
            create_actual_bound("values($1::int4)", vec![], params, formats),
        );
    }

    #[tokio::test]
    async fn default_type() {
        let (params, formats) = single_text_one();
        expect_actual_eq(
            create_expect_bound("select '1'"),
            create_actual_bound("select $1", vec![], params, formats),
        );
    }

    #[tokio::test]
    async fn cast_after_specific() {
        let (params, formats) = single_text_one();
        expect_actual_eq(
            create_expect_bound("select 1::varchar"),
            create_actual_bound(
                "select $1::varchar",
                vec![Some(DataType::Int32)],
                params,
                formats,
            ),
        );
    }

    #[tokio::test]
    async fn infer_case() {
        let (params, formats) = single_text_one();
        expect_actual_eq(
            create_expect_bound("select 1,1::INT4"),
            create_actual_bound("select $1,$1::INT4", vec![], params, formats),
        );
    }

    #[tokio::test]
    async fn subquery() {
        let (params, formats) = single_text_one();
        expect_actual_eq(
            create_expect_bound("select (select '1')"),
            create_actual_bound("select (select $1)", vec![], params, formats),
        );
    }
}